1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_vec_mult_f32.c
4 * Description: Floating-point matrix and vector multiplication
5 *
6 * $Date: 23 April 2021
7 *
8 * $Revision: V1.9.0
9 *
10 * Target Processor: Cortex-M and Cortex-A cores
11 * -------------------------------------------------------------------- */
12 /*
13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14 *
15 * SPDX-License-Identifier: Apache-2.0
16 *
17 * Licensed under the Apache License, Version 2.0 (the License); you may
18 * not use this file except in compliance with the License.
19 * You may obtain a copy of the License at
20 *
21 * www.apache.org/licenses/LICENSE-2.0
22 *
23 * Unless required by applicable law or agreed to in writing, software
24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 * See the License for the specific language governing permissions and
27 * limitations under the License.
28 */
29
30 #include "dsp/matrix_functions.h"
31
32
33 /**
34 * @ingroup groupMatrix
35 */
36
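/**
 * @defgroup MatrixVectMult Matrix Vector Multiplication
 *
 * Multiplies a matrix by a vector: for a numRows x numCols matrix A and an
 * input vector x of numCols elements, computes the numRows-element output
 * vector y with y[r] = sum over c of A[r][c] * x[c].
 *
 */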
43
44 /**
45 * @addtogroup MatrixVectMult
46 * @{
47 */
48
/**
 * @brief         Floating-point matrix and vector multiplication.
 * @param[in]     pSrcMat  points to the input matrix structure (numRows x numCols, row-major data)
 * @param[in]     pVec     points to the input vector (numCols elements)
 * @param[out]    pDst     points to the output vector (numRows elements)
 * @return        none
 */
55 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
56
57 #include "arm_helium_utils.h"
58
/*
 * Helium (MVE) implementation.
 *
 * Rows are processed in groups of 4, then 2, then 1. Each row's dot product
 * with the input vector is accumulated 4 lanes at a time in an f32x4_t and
 * reduced to a scalar with vecAddAcrossF32Mve once the row is finished.
 *
 * When numCols is not a multiple of 4, the last partial block of every row is
 * handled with tail-predicated loads. BOTH the vector and the matrix rows are
 * loaded with a zeroing predicated load, so:
 *  - nothing past the end of a row (and in particular past the end of
 *    pSrcMat->pData for the last row) is read, and
 *  - inactive lanes contribute exactly 0, instead of 0 * <neighbouring data>,
 *    which would be NaN if that data were Inf or NaN.
 */
void arm_mat_vec_mult_f32(
    const arm_matrix_instance_f32 *pSrcMat,
    const float32_t *pSrcVec,
    float32_t *pDstVec)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const float32_t *pSrcA = pSrcMat->pData;  /* first row of the current row group */
    float32_t *px = pDstVec;                   /* next output element */
    int32_t row = (int32_t) numRows;           /* rows still to be processed */
    uint32_t blkCnt;                           /* column block counter */

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4)
    {
        const float32_t *pSrcA0Vec = pSrcA;
        const float32_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float32_t *pSrcA2Vec = pSrcA1Vec + numCols;
        const float32_t *pSrcA3Vec = pSrcA2Vec + numCols;
        const float32_t *pInVec = pSrcVec;
        f32x4_t vecIn, vecA;
        f32x4_t acc0 = vdupq_n_f32(0.0f);
        f32x4_t acc1 = vdupq_n_f32(0.0f);
        f32x4_t acc2 = vdupq_n_f32(0.0f);
        f32x4_t acc3 = vdupq_n_f32(0.0f);

        /* full blocks of 4 columns; each vector block is shared by 4 rows */
        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 4;
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 4;
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 4;
            acc3 = vfmaq(acc3, vecIn, vecA);

            blkCnt--;
        }

        /* remaining 1..3 columns: predicated loads, never read past the row */
        blkCnt = numCols & 3U;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA2Vec, p0);
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA3Vec, p0);
            acc3 = vfmaq(acc3, vecIn, vecA);
        }

        /* sum the partial lanes of each row */
        *px++ = vecAddAcrossF32Mve(acc0);
        *px++ = vecAddAcrossF32Mve(acc1);
        *px++ = vecAddAcrossF32Mve(acc2);
        *px++ = vecAddAcrossF32Mve(acc3);

        pSrcA += numCols * 4;
        row -= 4;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2)
    {
        const float32_t *pSrcA0Vec = pSrcA;
        const float32_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float32_t *pInVec = pSrcVec;
        f32x4_t vecIn, vecA;
        f32x4_t acc0 = vdupq_n_f32(0.0f);
        f32x4_t acc1 = vdupq_n_f32(0.0f);

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 4;
            acc1 = vfmaq(acc1, vecIn, vecA);

            blkCnt--;
        }

        /* remaining 1..3 columns: predicated loads, never read past the row */
        blkCnt = numCols & 3U;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
        }

        *px++ = vecAddAcrossF32Mve(acc0);
        *px++ = vecAddAcrossF32Mve(acc1);

        pSrcA += numCols * 2;
        row -= 2;
    }

    /*
     * compute the last remaining row, if any
     */
    if (row >= 1)
    {
        const float32_t *pSrcA0Vec = pSrcA;
        const float32_t *pInVec = pSrcVec;
        f32x4_t vecIn, vecA;
        f32x4_t acc0 = vdupq_n_f32(0.0f);

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);

            blkCnt--;
        }

        /* remaining 1..3 columns: this is the final row of pData, so an
         * unpredicated load here would read past the end of the matrix */
        blkCnt = numCols & 3U;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
        }

        *px++ = vecAddAcrossF32Mve(acc0);
    }
}
288 #else
289
/*
 * Portable (scalar) implementation.
 *
 * pSrcMat->pData is read as a row-major numRows x numCols array: row r starts
 * at element r * numCols. pVec is read for numCols elements and numRows
 * results are written to pDst.
 *
 * Rows are processed four at a time so that each element loaded from pVec is
 * reused for four rows. The remaining 0-3 rows are processed one at a time,
 * with the column loop unrolled by two.
 */
void arm_mat_vec_mult_f32(const arm_matrix_instance_f32 *pSrcMat, const float32_t *pVec, float32_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const float32_t *pSrcA = pSrcMat->pData;
    const float32_t *pInA1;   /* row pointers of the current 4-row group */
    const float32_t *pInA2;
    const float32_t *pInA3;
    const float32_t *pInA4;
    const float32_t *pInVec;  /* walks the input vector */
    float32_t *px = pDst;     /* next output element */
    /*
     * Element offset of the current row inside pData. It grows up to
     * numRows * numCols, which can exceed 65535, so it must be 32-bit:
     * a 16-bit offset would wrap and make later rows read the wrong data.
     */
    uint32_t i = 0u;
    uint32_t row, colCnt;     /* loop counters */
    float32_t matData, matData2, vecData, vecData2;

    /* Process 4 rows at a time */
    row = numRows >> 2;
    while (row > 0u)
    {
        /* Initialize accumulators */
        float32_t sum1 = 0.0f;
        float32_t sum2 = 0.0f;
        float32_t sum3 = 0.0f;
        float32_t sum4 = 0.0f;

        /* Every row group starts again at the beginning of the vector */
        pInVec = pVec;

        /* Point at the first element of each of the 4 consecutive rows */
        pInA1 = pSrcA + i;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        pInA4 = pInA3 + numCols;

        /* One column per iteration; the vector element is shared by 4 rows */
        colCnt = numCols;
        while (colCnt > 0u)
        {
            vecData = *pInVec++;

            matData = *pInA1++;
            sum1 += matData * vecData;
            matData = *pInA2++;
            sum2 += matData * vecData;
            matData = *pInA3++;
            sum3 += matData * vecData;
            matData = *pInA4++;
            sum4 += matData * vecData;

            colCnt--;
        }

        /* Store the 4 dot products */
        *px++ = sum1;
        *px++ = sum2;
        *px++ = sum3;
        *px++ = sum4;

        i += numCols * 4u;
        row--;
    }

    /* Process the remaining 0-3 rows one at a time */
    row = numRows & 3u;
    while (row > 0u)
    {
        float32_t sum = 0.0f;

        pInVec = pVec;
        pInA1 = pSrcA + i;

        /* Two columns per iteration */
        colCnt = numCols >> 1;
        while (colCnt > 0u)
        {
            vecData = *pInVec++;
            vecData2 = *pInVec++;
            matData = *pInA1++;
            matData2 = *pInA1++;
            sum += matData * vecData;
            sum += matData2 * vecData2;
            colCnt--;
        }

        /* Odd column count: one column left over */
        colCnt = numCols & 1u;
        while (colCnt > 0u)
        {
            sum += *pInA1++ * *pInVec++;
            colCnt--;
        }

        *px++ = sum;
        i += numCols;
        row--;
    }
}
395 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
396
397 /**
 * @} end of MatrixVectMult group
399 */
400