• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_vec_mult_f32.c
4  * Description:  Floating-point matrix and vector multiplication
5  *
6  * $Date:        23 April 2021
7  *
8  * $Revision:    V1.9.0
9  *
10  * Target Processor: Cortex-M and Cortex-A cores
11  * -------------------------------------------------------------------- */
12 /*
13  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14  *
15  * SPDX-License-Identifier: Apache-2.0
16  *
17  * Licensed under the Apache License, Version 2.0 (the License); you may
18  * not use this file except in compliance with the License.
19  * You may obtain a copy of the License at
20  *
21  * www.apache.org/licenses/LICENSE-2.0
22  *
23  * Unless required by applicable law or agreed to in writing, software
24  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26  * See the License for the specific language governing permissions and
27  * limitations under the License.
28  */
29 
30 #include "dsp/matrix_functions.h"
31 
32 
33 /**
34  * @ingroup groupMatrix
35  */
36 
37 /**
38  * @defgroup MatrixVectMult Matrix Vector Multiplication
39  *
40  * Multiplies a matrix and a vector.
41  *
42  */
43 
44 /**
45  * @addtogroup MatrixVectMult
46  * @{
47  */
48 
49 /**
 * @brief Floating-point matrix and vector multiplication: pDst = pSrcMat * pVec.
 * @param[in]       pSrcMat points to the input matrix structure (row-major data, numRows x numCols)
 * @param[in]       pVec    points to the input vector (numCols elements)
 * @param[out]      pDst    points to the output vector (numRows elements)
54  */
55 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
56 
57 #include "arm_helium_utils.h"
58 
/*
 * Helium (MVE) variant.
 *
 * Rows are processed in groups of 4, then 2, then 1. Inside each row group the
 * columns are consumed 4 at a time; the remaining 1..3 columns are handled by a
 * single tail step that uses zeroing predicated loads (vctp32q + vldrwq_z_f32)
 * on BOTH the vector and the matrix rows.
 *
 * Matrix data is read row-major: row r starts at pData + r * numCols.
 * pSrcVec is read for numCols elements; numRows results are written to pDstVec.
 */
void arm_mat_vec_mult_f32(
    const arm_matrix_instance_f32   *pSrcMat,
    const float32_t                 *pSrcVec,
    float32_t                       *pDstVec)
{
    uint32_t         numRows = pSrcMat->numRows;
    uint32_t         numCols = pSrcMat->numCols;
    const float32_t *pSrcA = pSrcMat->pData;   /* first row of the current row group */
    float32_t       *px = pDstVec;
    uint32_t         row = numRows;            /* rows still to be produced */
    uint32_t         tailCnt = numCols & 3U;   /* columns left after the 4-wide loop */
    uint32_t         blkCnt;                   /* loop counter */

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4U)
    {
        const float32_t *pSrcA0Vec = pSrcA;
        const float32_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float32_t *pSrcA2Vec = pSrcA1Vec + numCols;
        const float32_t *pSrcA3Vec = pSrcA2Vec + numCols;
        const float32_t *pInVec = pSrcVec;
        f32x4_t          vecIn, vecA;
        f32x4_t          acc0 = vdupq_n_f32(0.0f);
        f32x4_t          acc1 = vdupq_n_f32(0.0f);
        f32x4_t          acc2 = vdupq_n_f32(0.0f);
        f32x4_t          acc3 = vdupq_n_f32(0.0f);

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 4;
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 4;
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 4;
            acc3 = vfmaq(acc3, vecIn, vecA);

            blkCnt--;
        }

        /*
         * tail: 1..3 remaining columns.
         * The matrix rows must be loaded with the predicate too. An unpredicated
         * vld1q would read past the end of the row (past the end of pData for
         * the last row). Even with the vector lanes zeroed, 0 * Inf or 0 * NaN
         * coming from that stray data would poison the accumulator.
         */
        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(tailCnt);

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA2Vec, p0);
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA3Vec, p0);
            acc3 = vfmaq(acc3, vecIn, vecA);
        }

        /*
         * Sum the partial parts
         */
        *px++ = vecAddAcrossF32Mve(acc0);
        *px++ = vecAddAcrossF32Mve(acc1);
        *px++ = vecAddAcrossF32Mve(acc2);
        *px++ = vecAddAcrossF32Mve(acc3);

        pSrcA += numCols * 4U;
        row -= 4U;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2U)
    {
        const float32_t *pSrcA0Vec = pSrcA;
        const float32_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float32_t *pInVec = pSrcVec;
        f32x4_t          vecIn, vecA;
        f32x4_t          acc0 = vdupq_n_f32(0.0f);
        f32x4_t          acc1 = vdupq_n_f32(0.0f);

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 4;
            acc1 = vfmaq(acc1, vecIn, vecA);

            blkCnt--;
        }

        /* tail: predicated loads on both operands (see 4-row case) */
        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(tailCnt);

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
        }

        /*
         * Sum the partial parts
         */
        *px++ = vecAddAcrossF32Mve(acc0);
        *px++ = vecAddAcrossF32Mve(acc1);

        pSrcA += numCols * 2U;
        row -= 2U;
    }

    /*
     * last remaining row, if any
     */
    if (row >= 1U)
    {
        const float32_t *pSrcA0Vec = pSrcA;
        const float32_t *pInVec = pSrcVec;
        f32x4_t          vecIn, vecA;
        f32x4_t          acc0 = vdupq_n_f32(0.0f);

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);

            blkCnt--;
        }

        /* tail: predicated loads on both operands (see 4-row case) */
        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(tailCnt);

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
        }

        /*
         * Sum the partial parts
         */
        *px = vecAddAcrossF32Mve(acc0);
    }
}
288 #else
289 
/*
 * Portable scalar variant.
 *
 * Computes pDst[r] = sum_c pData[r * numCols + c] * pVec[c] for every row r.
 * The matrix data is read row-major. Rows are handled 4 at a time (one column
 * per iteration, 4 independent accumulators). Any 1..3 leftover rows are
 * handled one at a time, with the columns unrolled by 2.
 *
 * Each row is addressed with a pointer that advances through pData. The
 * original code used a uint16_t element offset, which wrapped once the offset
 * exceeded 65535 elements (e.g. a 257x257 matrix) and silently read the wrong
 * rows.
 */
void arm_mat_vec_mult_f32(const arm_matrix_instance_f32 *pSrcMat, const float32_t *pVec, float32_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const float32_t *pSrcA = pSrcMat->pData;  /* first row not yet processed */
    float32_t *px = pDst;                     /* output write pointer */
    uint32_t row;                             /* row-group loop counter */
    uint32_t colCnt;                          /* column loop counter */

    /* Process 4 rows at a time */
    row = numRows >> 2;
    while (row > 0U)
    {
        /* Every row group restarts at the beginning of the vector */
        const float32_t *pInVec = pVec;
        /* Pointers to 4 consecutive matrix rows */
        const float32_t *pInA1 = pSrcA;
        const float32_t *pInA2 = pInA1 + numCols;
        const float32_t *pInA3 = pInA2 + numCols;
        const float32_t *pInA4 = pInA3 + numCols;
        float32_t sum1 = 0.0f;
        float32_t sum2 = 0.0f;
        float32_t sum3 = 0.0f;
        float32_t sum4 = 0.0f;

        /* One column per iteration: 1 vector value, 1 value from each of the 4 rows */
        colCnt = numCols;
        while (colCnt > 0U)
        {
            float32_t vecData = *pInVec++;

            sum1 += *pInA1++ * vecData;
            sum2 += *pInA2++ * vecData;
            sum3 += *pInA3++ * vecData;
            sum4 += *pInA4++ * vecData;

            colCnt--;
        }

        /* Store the 4 dot products */
        *px++ = sum1;
        *px++ = sum2;
        *px++ = sum3;
        *px++ = sum4;

        pSrcA += 4U * numCols;
        row--;
    }

    /* Process the remaining 0..3 rows one at a time */
    row = numRows & 3U;
    while (row > 0U)
    {
        const float32_t *pInVec = pVec;
        const float32_t *pInA1 = pSrcA;
        float32_t sum = 0.0f;

        /* Two columns per iteration; accumulation order matches one-at-a-time */
        colCnt = numCols >> 1;
        while (colCnt > 0U)
        {
            float32_t vecData = *pInVec++;
            float32_t vecData2 = *pInVec++;
            float32_t matData = *pInA1++;
            float32_t matData2 = *pInA1++;

            sum += matData * vecData;
            sum += matData2 * vecData2;
            colCnt--;
        }

        /* Odd column count: one column left */
        if ((numCols & 1U) != 0U)
        {
            sum += *pInA1 * *pInVec;
        }

        *px++ = sum;
        pSrcA += numCols;
        row--;
    }
}
395 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
396 
397 /**
 * @} end of MatrixVectMult group
399  */
400