1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_vec_mult_q7.c
4 * Description: Q7 matrix and vector multiplication
5 *
6 * $Date: 23 April 2021
7 *
8 * $Revision: V1.9.0
9 *
10 * Target Processor: Cortex-M and Cortex-A cores
11 * -------------------------------------------------------------------- */
12 /*
13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14 *
15 * SPDX-License-Identifier: Apache-2.0
16 *
17 * Licensed under the Apache License, Version 2.0 (the License); you may
18 * not use this file except in compliance with the License.
19 * You may obtain a copy of the License at
20 *
21 * www.apache.org/licenses/LICENSE-2.0
22 *
23 * Unless required by applicable law or agreed to in writing, software
24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 * See the License for the specific language governing permissions and
27 * limitations under the License.
28 */
29
30 #include "dsp/matrix_functions.h"
31
32 /**
33 * @ingroup groupMatrix
34 */
35
36
37
38 /**
39 * @addtogroup MatrixVectMult
40 * @{
41 */
42
43 /**
44 * @brief Q7 matrix and vector multiplication.
45 * @param[in] *pSrcMat points to the input matrix structure
46 * @param[in] *pVec points to the input vector
47 * @param[out] *pDst points to the output vector
48 */
49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
50
51 #include "arm_helium_utils.h"
52
/*
 * Q7 matrix-by-vector product, Helium (MVE) implementation.
 *
 *   pSrcMat  matrix instance; numRows, numCols and pData are read.
 *   pSrcVec  input vector; numCols elements are read for every row.
 *   pDstVec  output vector; one element is written per matrix row.
 *
 * Each output element is the 32-bit dot product of one matrix row with the
 * input vector, shifted right by 7 bits and saturated to 8 bits.
 *
 * Rows are processed four at a time, then (if left over) as one pair,
 * then as one single row. Within a row, columns are consumed 16 at a time;
 * the remaining (numCols & 0xF) columns are handled with a lane predicate.
 * All loads in the tail are predicated so that no byte beyond the end of
 * the row (and therefore beyond the end of the matrix or vector buffer)
 * is accessed.
 */
void arm_mat_vec_mult_q7(
    const arm_matrix_instance_q7 * pSrcMat,
    const q7_t     *pSrcVec,
    q7_t           *pDstVec)
{
    const q7_t *pMatSrc = pSrcMat->pData;   /* first row not yet processed */
    const q7_t *pMat0, *pMat1;
    uint32_t    numRows = pSrcMat->numRows;
    uint32_t    numCols = pSrcMat->numCols;
    q7_t       *px;                         /* output write pointer */
    int32_t     row;                        /* rows still to be processed */
    uint32_t    blkCnt;                     /* column loop counter */

    row = numRows;
    px = pDstVec;

    /*
     * Main loop: four rows per iteration, one 32-bit accumulator per row
     */
    while (row >= 4)
    {
        q7_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pMat3Vec, *pVec;
        const q7_t *pMat2, *pMat3;
        q31_t       acc0, acc1, acc2, acc3;
        q7x16_t     vecMatA0, vecMatA1, vecMatA2, vecMatA3, vecIn;

        /*
         * Set up one pointer per row; consecutive rows are numCols apart
         */
        pMat0 = pMatSrc;
        pMat1 = pMat0 + numCols;
        pMat2 = pMat1 + numCols;
        pMat3 = pMat2 + numCols;

        acc0 = 0L;
        acc1 = 0L;
        acc2 = 0L;
        acc3 = 0L;

        pMat0Vec = pMat0;
        pMat1Vec = pMat1;
        pMat2Vec = pMat2;
        pMat3Vec = pMat3;
        /*
         * Every group of rows restarts at the beginning of the input vector
         */
        pVec = pSrcVec;

        blkCnt = numCols >> 4;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 16;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 16;
            vecMatA2 = vld1q(pMat2Vec);
            pMat2Vec += 16;
            vecMatA3 = vld1q(pMat3Vec);
            pMat3Vec += 16;
            vecIn = vld1q(pVec);
            pVec += 16;

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);
            acc2 = vmladavaq(acc2, vecIn, vecMatA2);
            acc3 = vmladavaq(acc3, vecIn, vecMatA3);

            blkCnt--;
        }

        /*
         * Tail: fewer than 16 columns remain. Both the matrix rows and the
         * vector are loaded under the predicate so that nothing past the
         * end of the row is read.
         */
        blkCnt = numCols & 0xF;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp8q(blkCnt);

            vecMatA0 = vldrbq_z_s8(pMat0Vec, p0);
            vecMatA1 = vldrbq_z_s8(pMat1Vec, p0);
            vecMatA2 = vldrbq_z_s8(pMat2Vec, p0);
            vecMatA3 = vldrbq_z_s8(pMat3Vec, p0);
            vecIn = vldrbq_z_s8(pVec, p0);

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);
            acc2 = vmladavaq(acc2, vecIn, vecMatA2);
            acc3 = vmladavaq(acc3, vecIn, vecMatA3);
        }

        /*
         * Scale the accumulators down by 7 bits and saturate to 8 bits
         */
        *px++ = __SSAT(acc0 >> 7, 8);
        *px++ = __SSAT(acc1 >> 7, 8);
        *px++ = __SSAT(acc2 >> 7, 8);
        *px++ = __SSAT(acc3 >> 7, 8);

        pMatSrc += numCols * 4;
        /*
         * Decrement the row loop counter
         */
        row -= 4;
    }

    /*
     * Process a remaining pair of rows, if any
     */
    if (row >= 2)
    {
        q7_t const *pMat0Vec, *pMat1Vec, *pVec;
        q31_t       acc0, acc1;
        q7x16_t     vecMatA0, vecMatA1, vecIn;

        pMat0 = pMatSrc;
        pMat1 = pMat0 + numCols;

        acc0 = 0;
        acc1 = 0;

        pMat0Vec = pMat0;
        pMat1Vec = pMat1;
        /*
         * Restart at the beginning of the input vector
         */
        pVec = pSrcVec;

        blkCnt = numCols >> 4;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 16;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 16;
            vecIn = vld1q(pVec);
            pVec += 16;

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);

            blkCnt--;
        }

        /*
         * Tail: predicated loads only (see the four-row loop)
         */
        blkCnt = numCols & 0xF;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp8q(blkCnt);

            vecMatA0 = vldrbq_z_s8(pMat0Vec, p0);
            vecMatA1 = vldrbq_z_s8(pMat1Vec, p0);
            vecIn = vldrbq_z_s8(pVec, p0);

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);
        }

        *px++ = __SSAT(acc0 >> 7, 8);
        *px++ = __SSAT(acc1 >> 7, 8);

        pMatSrc += numCols * 2;
        /*
         * Decrement the row loop counter
         */
        row -= 2;
    }

    /*
     * Process the last remaining row, if any
     */
    if (row >= 1)
    {
        q7_t const *pMat0Vec, *pVec;
        q31_t       acc0;
        q7x16_t     vecMatA0, vecIn;

        pMat0 = pMatSrc;

        acc0 = 0;

        pMat0Vec = pMat0;
        /*
         * Restart at the beginning of the input vector
         */
        pVec = pSrcVec;

        blkCnt = numCols >> 4;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 16;
            vecIn = vld1q(pVec);
            pVec += 16;

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            blkCnt--;
        }

        /*
         * Tail: predicated loads only (see the four-row loop)
         */
        blkCnt = numCols & 0xF;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp8q(blkCnt);

            vecMatA0 = vldrbq_z_s8(pMat0Vec, p0);
            vecIn = vldrbq_z_s8(pVec, p0);
            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
        }
        *px++ = __SSAT(acc0 >> 7, 8);
    }
}
277
278 #else
/*
 * Q7 matrix-by-vector product, scalar / packed-SIMD implementation.
 *
 *   pSrcMat  matrix instance; numRows, numCols and pData are read.
 *   pVec     input vector; numCols elements are read for every row.
 *   pDst     output vector; one element is written per matrix row.
 *
 * Each output element is the 32-bit dot product of one matrix row with the
 * input vector, shifted right by 7 bits and saturated to 8 bits.
 *
 * Rows are processed in groups of four, then one by one for the remainder.
 * Columns are processed in groups of four packed bytes, then one by one.
 */
void arm_mat_vec_mult_q7(const arm_matrix_instance_q7 *pSrcMat, const q7_t *pVec, q7_t *pDst)
{
    uint32_t nRows = pSrcMat->numRows;
    uint32_t nCols = pSrcMat->numCols;
    const q7_t *pMatBase = pSrcMat->pData;
    const q7_t *pRow0;          /* read pointers, one per row in flight */
    const q7_t *pRow1;
    const q7_t *pRow2;
    const q7_t *pRow3;
    const q7_t *pVecRd;         /* read pointer into the input vector */
    q7_t *pOut = pDst;          /* write pointer into the output vector */
    uint32_t rowOffset = 0u;    /* element offset of the next unprocessed row */
    uint32_t rowsLeft, colsLeft;

    /* Packed operands: "Lo" holds bytes 0 and 2, "Hi" holds bytes 1 and 3,
     * each sign-extended to 16 bits. */
    q31_t matLo, matHi, vecLo, vecHi;

    /* ---- Groups of four rows ---- */
    for (rowsLeft = nRows >> 2; rowsLeft > 0; rowsLeft--)
    {
        q31_t acc0 = 0;
        q31_t acc1 = 0;
        q31_t acc2 = 0;
        q31_t acc3 = 0;

        /* Each group of rows starts again at the head of the vector */
        pVecRd = pVec;

        /* Consecutive rows are nCols elements apart */
        pRow0 = pMatBase + rowOffset;
        pRow1 = pRow0 + nCols;
        pRow2 = pRow1 + nCols;
        pRow3 = pRow2 + nCols;

        /* Four columns per pass: one packed vector word is shared by all
         * four rows. */
        for (colsLeft = nCols >> 2; colsLeft > 0u; colsLeft--)
        {
            vecLo = read_q7x4_ia((q7_t **) &pVecRd);
            vecHi = __SXTB16(__ROR(vecLo, 8));
            vecLo = __SXTB16(vecLo);

            matLo = read_q7x4_ia((q7_t **) &pRow0);
            matHi = __SXTB16(__ROR(matLo, 8));
            matLo = __SXTB16(matLo);
            acc0 = __SMLAD(matLo, vecLo, acc0);
            acc0 = __SMLAD(matHi, vecHi, acc0);

            matLo = read_q7x4_ia((q7_t **) &pRow1);
            matHi = __SXTB16(__ROR(matLo, 8));
            matLo = __SXTB16(matLo);
            acc1 = __SMLAD(matLo, vecLo, acc1);
            acc1 = __SMLAD(matHi, vecHi, acc1);

            matLo = read_q7x4_ia((q7_t **) &pRow2);
            matHi = __SXTB16(__ROR(matLo, 8));
            matLo = __SXTB16(matLo);
            acc2 = __SMLAD(matLo, vecLo, acc2);
            acc2 = __SMLAD(matHi, vecHi, acc2);

            matLo = read_q7x4_ia((q7_t **) &pRow3);
            matHi = __SXTB16(__ROR(matLo, 8));
            matLo = __SXTB16(matLo);
            acc3 = __SMLAD(matLo, vecLo, acc3);
            acc3 = __SMLAD(matHi, vecHi, acc3);
        }

        /* Leftover columns (0..3), one element at a time */
        for (colsLeft = nCols & 3u; colsLeft > 0; colsLeft--)
        {
            vecLo = *pVecRd++;
            acc0 += *pRow0++ * vecLo;
            acc1 += *pRow1++ * vecLo;
            acc2 += *pRow2++ * vecLo;
            acc3 += *pRow3++ * vecLo;
        }

        /* Scale down by 7 bits, saturate to 8 bits and store */
        *pOut++ = (q7_t)(__SSAT((acc0 >> 7), 8));
        *pOut++ = (q7_t)(__SSAT((acc1 >> 7), 8));
        *pOut++ = (q7_t)(__SSAT((acc2 >> 7), 8));
        *pOut++ = (q7_t)(__SSAT((acc3 >> 7), 8));

        /* Skip the four rows just processed */
        rowOffset = rowOffset + nCols * 4;
    }

    /* ---- Leftover rows (0..3), one row at a time ---- */
    for (rowsLeft = nRows & 3u; rowsLeft > 0; rowsLeft--)
    {
        q31_t acc = 0;

        pVecRd = pVec;
        pRow0 = pMatBase + rowOffset;

        /* Four columns per pass */
        for (colsLeft = nCols >> 2; colsLeft > 0; colsLeft--)
        {
            vecLo = read_q7x4_ia((q7_t **) &pVecRd);
            vecHi = __SXTB16(__ROR(vecLo, 8));
            vecLo = __SXTB16(vecLo);

            matLo = read_q7x4_ia((q7_t **) &pRow0);
            matHi = __SXTB16(__ROR(matLo, 8));
            matLo = __SXTB16(matLo);

            acc = __SMLAD(matLo, vecLo, acc);
            acc = __SMLAD(matHi, vecHi, acc);
        }

        /* Leftover columns (0..3) */
        for (colsLeft = nCols & 3u; colsLeft > 0; colsLeft--)
        {
            acc += *pRow0++ * *pVecRd++;
        }

        *pOut++ = (q7_t)(__SSAT((acc >> 7), 8));
        rowOffset = rowOffset + nCols;
    }
}
416 #endif /* defined(ARM_MATH_MVEI) */
417
418 /**
419 * @} end of MatrixVectMult group
420 */
421