• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_vec_mult_f16.c
4  * Description:  Floating-point matrix and vector multiplication
5  *
6  * $Date:        23 April 2021
7  *
8  * $Revision:    V1.9.0
9  *
10  * Target Processor: Cortex-M and Cortex-A cores
11  * -------------------------------------------------------------------- */
12 /*
13  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14  *
15  * SPDX-License-Identifier: Apache-2.0
16  *
17  * Licensed under the Apache License, Version 2.0 (the License); you may
18  * not use this file except in compliance with the License.
19  * You may obtain a copy of the License at
20  *
21  * www.apache.org/licenses/LICENSE-2.0
22  *
23  * Unless required by applicable law or agreed to in writing, software
24  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26  * See the License for the specific language governing permissions and
27  * limitations under the License.
28  */
29 
30 #include "dsp/matrix_functions_f16.h"
31 
32 #if defined(ARM_FLOAT16_SUPPORTED)
33 
34 
35 /**
36  * @ingroup groupMatrix
37  */
38 
39 
40 /**
41  * @addtogroup MatrixVectMult
42  * @{
43  */
44 
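45 /**
46  * @brief Half-precision floating-point matrix-vector multiplication (pDst = pSrcMat * pVec).
47  * @param[in]       pSrcMat points to the input matrix structure (numRows x numCols, stored row by row)
48  * @param[in]       pVec    points to the input vector of numCols elements
49  * @param[out]      pDst    points to the output vector of numRows elements
50  */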
51 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
52 
53 #include "arm_helium_utils.h"
54 
void arm_mat_vec_mult_f16(
    const arm_matrix_instance_f16   *pSrcMat,
    const float16_t                 *pSrcVec,
    float16_t                       *pDstVec)
{
    /*
     * Helium (MVE) implementation of pDstVec = pSrcMat * pSrcVec.
     * The matrix is read row by row with a stride of numCols elements.
     * Rows are handled 4 at a time, then 2, then 1. Within a row, 8
     * half-precision lanes are multiply-accumulated per iteration and the
     * lanes of each accumulator are summed once the row is finished.
     */
    uint32_t         numRows = pSrcMat->numRows;
    uint32_t         numCols = pSrcMat->numCols;
    const float16_t *pSrcA = pSrcMat->pData;   /* first element of the current row group */
    float16_t       *px = pDstVec;             /* next output element */
    int32_t          row = (int32_t) numRows;  /* rows still to be processed */
    uint32_t         tailCnt = numCols & 7U;   /* columns left after the 8-wide blocks */
    uint32_t         blkCnt;                   /* loop counter */

    /*
     * Tail handling (all three row groups below):
     * BOTH the vector and the matrix loads are zero-predicated with p0, so no
     * element outside the current row is read. An unpredicated matrix load
     * would read past the end of pData on the last row, and any Inf/NaN found
     * beyond the row (e.g. in the next row) would turn 0 * x into NaN and
     * poison the lane sum.
     */

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4)
    {
        const float16_t *pSrcA0Vec = pSrcA;
        const float16_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float16_t *pSrcA2Vec = pSrcA1Vec + numCols;
        const float16_t *pSrcA3Vec = pSrcA2Vec + numCols;
        const float16_t *pInVec = pSrcVec;
        f16x8_t          vecIn, vecA;
        f16x8_t          acc0 = vdupq_n_f16(0.0f);
        f16x8_t          acc1 = vdupq_n_f16(0.0f);
        f16x8_t          acc2 = vdupq_n_f16(0.0f);
        f16x8_t          acc3 = vdupq_n_f16(0.0f);

        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 8;
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 8;
            acc3 = vfmaq(acc3, vecIn, vecA);

            blkCnt--;
        }

        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(tailCnt);

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA2Vec, p0);
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA3Vec, p0);
            acc3 = vfmaq(acc3, vecIn, vecA);
        }

        /* Sum the lanes of each accumulator into one output element */
        *px++ = vecAddAcrossF16Mve(acc0);
        *px++ = vecAddAcrossF16Mve(acc1);
        *px++ = vecAddAcrossF16Mve(acc2);
        *px++ = vecAddAcrossF16Mve(acc3);

        pSrcA += numCols * 4;
        row -= 4;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2)
    {
        const float16_t *pSrcA0Vec = pSrcA;
        const float16_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float16_t *pInVec = pSrcVec;
        f16x8_t          vecIn, vecA;
        f16x8_t          acc0 = vdupq_n_f16(0.0f);
        f16x8_t          acc1 = vdupq_n_f16(0.0f);

        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            acc1 = vfmaq(acc1, vecIn, vecA);

            blkCnt--;
        }

        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(tailCnt);

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
        }

        *px++ = vecAddAcrossF16Mve(acc0);
        *px++ = vecAddAcrossF16Mve(acc1);

        pSrcA += numCols * 2;
        row -= 2;
    }

    /*
     * compute the last remaining row, if any
     */
    if (row >= 1)
    {
        const float16_t *pSrcA0Vec = pSrcA;
        const float16_t *pInVec = pSrcVec;
        f16x8_t          vecIn, vecA;
        f16x8_t          acc0 = vdupq_n_f16(0.0f);

        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            acc0 = vfmaq(acc0, vecIn, vecA);

            blkCnt--;
        }

        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(tailCnt);

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
        }

        *px++ = vecAddAcrossF16Mve(acc0);
    }
}
284 #else
void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst)
{
    /*
     * Portable C implementation of pDst = pSrcMat * pVec.
     * The matrix is read row by row with a stride of numCols elements; each
     * output element is the dot product of one matrix row with the numCols
     * elements of pVec. Products are accumulated in float16_t.
     */
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    /*
     * Start of the current row (group). Advancing a pointer instead of a
     * 16-bit element offset keeps addressing correct for matrices with more
     * than 65535 elements, where a uint16_t offset would wrap around.
     */
    const float16_t *pRow = pSrcMat->pData;
    const float16_t *pInA1;      /* matrix row pointers */
    const float16_t *pInA2;
    const float16_t *pInA3;
    const float16_t *pInA4;
    const float16_t *pInVec;     /* input vector pointer */
    float16_t *px = pDst;        /* next output element */
    uint32_t row, colCnt;        /* loop counters */
    float16_t matData, matData2, vecData, vecData2;

    /* Process 4 rows at a time so each vector element is loaded once for 4 rows */
    row = numRows >> 2;
    while (row > 0U) {
        float16_t sum1 = 0.0f;
        float16_t sum2 = 0.0f;
        float16_t sum3 = 0.0f;
        float16_t sum4 = 0.0f;

        /* Every row starts from the beginning of the vector */
        pInVec = pVec;
        pInA1 = pRow;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        pInA4 = pInA3 + numCols;

        /* One column per iteration: 1 vector value, 1 value from each of the 4 rows */
        colCnt = numCols;
        while (colCnt > 0U) {
            vecData = *pInVec++;
            matData = *pInA1++;
            sum1 += matData * vecData;
            matData = *pInA2++;
            sum2 += matData * vecData;
            matData = *pInA3++;
            sum3 += matData * vecData;
            matData = *pInA4++;
            sum4 += matData * vecData;

            colCnt--;
        }

        /* Store the 4 dot products */
        *px++ = sum1;
        *px++ = sum2;
        *px++ = sum3;
        *px++ = sum4;

        pRow += numCols * 4U;
        row--;
    }

    /* Remaining 0..3 rows, one at a time, 2 columns per iteration */
    row = numRows & 3U;
    while (row > 0U) {
        float16_t sum = 0.0f;

        pInVec = pVec;
        pInA1 = pRow;

        colCnt = numCols >> 1;
        while (colCnt > 0U) {
            vecData = *pInVec++;
            vecData2 = *pInVec++;
            matData = *pInA1++;
            matData2 = *pInA1++;
            sum += matData * vecData;
            sum += matData2 * vecData2;
            colCnt--;
        }

        /* Odd column count: one last column remains */
        if ((numCols & 1U) != 0U) {
            sum += *pInA1 * *pInVec;
        }

        *px++ = sum;
        pRow += numCols;
        row--;
    }
}
389 #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
390 
391 /**
392  * @} end of MatrixVectMult group
393  */
394 
395 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
396 
397