1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_vec_mult_f16.c
4 * Description: Floating-point matrix and vector multiplication
5 *
6 * $Date: 23 April 2021
7 *
8 * $Revision: V1.9.0
9 *
10 * Target Processor: Cortex-M and Cortex-A cores
11 * -------------------------------------------------------------------- */
12 /*
13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14 *
15 * SPDX-License-Identifier: Apache-2.0
16 *
17 * Licensed under the Apache License, Version 2.0 (the License); you may
18 * not use this file except in compliance with the License.
19 * You may obtain a copy of the License at
20 *
21 * www.apache.org/licenses/LICENSE-2.0
22 *
23 * Unless required by applicable law or agreed to in writing, software
24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 * See the License for the specific language governing permissions and
27 * limitations under the License.
28 */
29
30 #include "dsp/matrix_functions_f16.h"
31
32 #if defined(ARM_FLOAT16_SUPPORTED)
33
34
35 /**
36 * @ingroup groupMatrix
37 */
38
39
40 /**
41 * @addtogroup MatrixVectMult
42 * @{
43 */
44
/**
 * @brief         Floating-point (f16) matrix-vector multiplication: pDst = pSrcMat * pVec.
 * @param[in]     pSrcMat  points to the input matrix structure (numRows x numCols, stored row by row)
 * @param[in]     pVec     points to the input vector of numCols elements
 * @param[out]    pDst     points to the output vector of numRows elements
 * @return        none
 */
51 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
52
53 #include "arm_helium_utils.h"
54
/*
 * Helium (MVE) implementation.
 *
 * Rows are processed in groups of 4, then 2, then 1. For each row, the dot
 * product with the input vector is accumulated 8 half-precision lanes at a
 * time with vfmaq. The last numCols % 8 columns use a vctp16q tail
 * predicate. The per-lane partial sums are then combined with
 * vecAddAcrossF16Mve() and stored to the output.
 *
 * pSrcMat  input matrix, numRows x numCols, row r starting at pData + r * numCols
 * pSrcVec  input vector, numCols elements
 * pDstVec  output vector, numRows elements
 */
void arm_mat_vec_mult_f16(
    const arm_matrix_instance_f16 *pSrcMat,
    const float16_t *pSrcVec,
    float16_t *pDstVec)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const float16_t *pSrcA = pSrcMat->pData;
    const float16_t *pInA0;
    const float16_t *pInA1;
    float16_t *px;
    int32_t row;
    uint32_t blkCnt;            /* loop counters */

    row = numRows;
    px = pDstVec;

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4)
    {
        const float16_t *pInA2, *pInA3;
        float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
        f16x8_t vecIn, acc0, acc1, acc2, acc3;

        /*
         * Initialize the pointers to 4 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        /*
         * Every row group restarts at the beginning of the vector
         */
        pInVec = pSrcVec;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);
        acc1 = vdupq_n_f16(0.0f);
        acc2 = vdupq_n_f16(0.0f);
        acc3 = vdupq_n_f16(0.0f);

        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;
        pSrcA2Vec = pInA2;
        pSrcA3Vec = pInA3;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            f16x8_t vecA;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 8;
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 8;
            acc3 = vfmaq(acc3, vecIn, vecA);

            blkCnt--;
        }
        /*
         * Tail: the last numCols % 8 columns.
         *
         * Every load is zero-predicated, not only the vector one. An
         * unpredicated 8-lane load of a matrix row would read past the end
         * of the row, and past the end of pData for the last row. Any Inf or
         * NaN picked up there would still poison the sum, because
         * 0 * Inf and 0 * NaN are NaN.
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(blkCnt);
            f16x8_t vecA;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA2Vec, p0);
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA3Vec, p0);
            acc3 = vfmaq(acc3, vecIn, vecA);
        }
        /*
         * Sum the partial parts
         */
        *px++ = vecAddAcrossF16Mve(acc0);
        *px++ = vecAddAcrossF16Mve(acc1);
        *px++ = vecAddAcrossF16Mve(acc2);
        *px++ = vecAddAcrossF16Mve(acc3);

        pSrcA += numCols * 4;
        /*
         * Decrement the row loop counter
         */
        row -= 4;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2)
    {
        float16_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
        f16x8_t vecIn, acc0, acc1;

        /*
         * Initialize the pointers to 2 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVec;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);
        acc1 = vdupq_n_f16(0.0f);
        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            f16x8_t vecA;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            acc1 = vfmaq(acc1, vecIn, vecA);

            blkCnt--;
        }
        /*
         * Tail: zero-predicate the matrix loads too (see the 4-row tail above
         * for why the vector predicate alone is not enough).
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(blkCnt);
            f16x8_t vecA;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
        }
        /*
         * Sum the partial parts
         */
        *px++ = vecAddAcrossF16Mve(acc0);
        *px++ = vecAddAcrossF16Mve(acc1);

        pSrcA += numCols * 2;
        row -= 2;
    }

    if (row >= 1)
    {
        f16x8_t vecIn, acc0;
        float16_t const *pSrcA0Vec, *pInVec;

        /*
         * Initialize the pointers to last MatrixA row
         */
        pInA0 = pSrcA;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVec;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);

        pSrcA0Vec = pInA0;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            f16x8_t vecA;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            acc0 = vfmaq(acc0, vecIn, vecA);

            blkCnt--;
        }
        /*
         * Tail: this is the last matrix row, so an unpredicated row load here
         * would read beyond the end of pData.
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(blkCnt);
            f16x8_t vecA;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
        }
        /*
         * Sum the partial parts
         */
        *px++ = vecAddAcrossF16Mve(acc0);
    }
}
284 #else
/*
 * Scalar (portable) implementation.
 *
 * Rows are processed four at a time. Each inner-loop iteration consumes one
 * vector element and one element from each of the four rows. The remaining
 * numRows % 4 rows are processed one at a time, with the column loop
 * unrolled by two plus a final odd column. Accumulation is done in float16_t.
 *
 * pSrcMat  input matrix, numRows x numCols, row r starting at pData + r * numCols
 * pVec     input vector, numCols elements
 * pDst     output vector, numRows elements
 */
void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const float16_t *pSrcA = pSrcMat->pData;
    const float16_t *pInA1;     /* matrix row pointers for the current group of rows */
    const float16_t *pInA2;
    const float16_t *pInA3;
    const float16_t *pInA4;
    const float16_t *pInVec;    /* input vector read pointer */
    float16_t *px;              /* output vector write pointer */
    /*
     * Element offset of the current row inside pData. This must be wider than
     * 16 bits: numRows * numCols easily exceeds 65535 (e.g. 256 x 256). A
     * uint16_t offset would silently wrap, and later rows would be read from
     * the wrong place.
     */
    uint32_t i;
    uint32_t row, colCnt;       /* loop counters */
    float16_t matData, matData2, vecData, vecData2;

    /* Process 4 rows at a time */
    row = numRows >> 2;
    i = 0u;
    px = pDst;

    /* The following loop performs the dot-product of each row in pSrcA with the vector */
    while (row > 0u) {
        /* For every group of rows, restart at the beginning of the vector */
        pInVec = pVec;

        /* Initialize accumulators */
        float16_t sum1 = 0.0f;
        float16_t sum2 = 0.0f;
        float16_t sum3 = 0.0f;
        float16_t sum4 = 0.0f;

        /* One column per iteration (this loop is not unrolled over columns) */
        colCnt = numCols;

        /* Pointers to the first element of 4 consecutive matrix rows */
        pInA1 = pSrcA + i;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        pInA4 = pInA3 + numCols;

        while (colCnt > 0u) {
            /* Read 1 value from the vector ... */
            vecData = *(pInVec)++;
            /* ... and multiply-accumulate it with 1 value from each of the 4 rows */
            matData = *(pInA1)++;
            sum1 += matData * vecData;
            matData = *(pInA2)++;
            sum2 += matData * vecData;
            matData = *(pInA3)++;
            sum3 += matData * vecData;
            matData = *(pInA4)++;
            sum4 += matData * vecData;

            colCnt--;
        }

        /* Store the 4 results (floating point: no saturation involved) */
        *px++ = sum1;
        *px++ = sum2;
        *px++ = sum3;
        *px++ = sum4;

        /* Advance to the next group of 4 rows */
        i = i + numCols * 4u;

        row--;
    }

    /* Process the remaining numRows % 4 rows, two columns per iteration */
    row = numRows & 3u;
    while (row > 0u) {

        float16_t sum = 0.0f;
        pInVec = pVec;
        pInA1 = pSrcA + i;

        colCnt = numCols >> 1;

        while (colCnt > 0u) {
            vecData = *(pInVec)++;
            vecData2 = *(pInVec)++;
            matData = *(pInA1)++;
            matData2 = *(pInA1)++;
            sum += matData * vecData;
            sum += matData2 * vecData2;
            colCnt--;
        }
        /* Process the odd column left over, if numCols is odd */
        colCnt = numCols & 1u;
        while (colCnt > 0u) {
            sum += *pInA1++ * *pInVec++;
            colCnt--;
        }

        *px++ = sum;
        i = i + numCols;
        row--;
    }
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
390
/**
 * @} end of MatrixVectMult group
 */
394
395 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
396
397