1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_fast_q31.c
4 * Description: Q31 matrix multiplication (fast variant)
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31 /**
32 @ingroup groupMatrix
33 */
34
35 /**
36 @addtogroup MatrixMult
37 @{
38 */
39
40 /**
41 @brief Q31 matrix multiplication (fast variant).
42 @param[in] pSrcA points to the first input matrix structure
43 @param[in] pSrcB points to the second input matrix structure
44 @param[out] pDst points to output matrix structure
45 @return execution status
46 - \ref ARM_MATH_SUCCESS : Operation successful
47 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
48
49 @par Scaling and Overflow Behavior
50 The difference between the function \ref arm_mat_mult_q31() and this fast variant is that
51 the fast variant use a 32-bit rather than a 64-bit accumulator.
52 The result of each 1.31 x 1.31 multiplication is truncated to
53 2.30 format. These intermediate results are accumulated in a 32-bit register in 2.30
54 format. Finally, the accumulator is saturated and converted to a 1.31 result.
55 @par
56 The fast version has the same overflow behavior as the standard version but provides
57 less precision since it discards the low 32 bits of each multiplication result.
58 In order to avoid overflows completely the input signals must be scaled down.
59 Scale down one of the input matrices by log2(numColsA) bits to avoid overflows,
60 as a total of numColsA additions are computed internally for each output element.
61 @remark
62 Refer to \ref arm_mat_mult_q31() for a slower implementation of this function
63 which uses 64-bit accumulation to provide higher precision.
64 */
65
arm_mat_mult_fast_q31(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst)66 arm_status arm_mat_mult_fast_q31(
67 const arm_matrix_instance_q31 * pSrcA,
68 const arm_matrix_instance_q31 * pSrcB,
69 arm_matrix_instance_q31 * pDst)
70 {
71 q31_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
72 q31_t *pInB = pSrcB->pData; /* Input data matrix pointer B */
73 q31_t *pInA2;
74 q31_t *px; /* Temporary output data matrix pointer */
75 q31_t *px2;
76 q31_t sum1, sum2, sum3, sum4; /* Accumulator */
77 q31_t inA1, inA2, inB1, inB2;
78 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
79 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
80 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
81 uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* Loop counters */
82 arm_status status; /* Status of matrix multiplication */
83
84
85 #ifdef ARM_MATH_MATRIX_CHECK
86
87 /* Check for matrix mismatch condition */
88 if ((pSrcA->numCols != pSrcB->numRows) ||
89 (pSrcA->numRows != pDst->numRows) ||
90 (pSrcB->numCols != pDst->numCols) )
91 {
92 /* Set status as ARM_MATH_SIZE_MISMATCH */
93 status = ARM_MATH_SIZE_MISMATCH;
94 }
95 else
96
97 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
98
99 {
100 px = pDst->pData;
101
102 row = row >> 1U;
103 px2 = px + numColsB;
104
105 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
106 /* row loop */
107 while (row > 0U)
108 {
109 /* For every row wise process, column loop counter is to be initiated */
110 col = numColsB;
111
112 /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
113 pInB = pSrcB->pData;
114
115 j = 0U;
116
117 col = col >> 1U;
118
119 /* column loop */
120 while (col > 0U)
121 {
122 /* Set the variable sum, that acts as accumulator, to zero */
123 sum1 = 0;
124 sum2 = 0;
125 sum3 = 0;
126 sum4 = 0;
127
128 /* Initiate data pointers */
129 pInA = pSrcA->pData + i;
130 pInB = pSrcB->pData + j;
131 pInA2 = pInA + numColsA;
132
133 colCnt = numColsA;
134
135 /* matrix multiplication */
136 while (colCnt > 0U)
137 {
138 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
139
140 inA1 = *pInA++;
141 inB1 = pInB[0];
142 inA2 = *pInA2++;
143 inB2 = pInB[1];
144 pInB += numColsB;
145
146 #if defined (ARM_MATH_DSP)
147 sum1 = __SMMLA(inA1, inB1, sum1);
148 sum2 = __SMMLA(inA1, inB2, sum2);
149 sum3 = __SMMLA(inA2, inB1, sum3);
150 sum4 = __SMMLA(inA2, inB2, sum4);
151 #else
152 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA1 * inB1)) >> 32);
153 sum2 = (q31_t) ((((q63_t) sum2 << 32) + ((q63_t) inA1 * inB2)) >> 32);
154 sum3 = (q31_t) ((((q63_t) sum3 << 32) + ((q63_t) inA2 * inB1)) >> 32);
155 sum4 = (q31_t) ((((q63_t) sum4 << 32) + ((q63_t) inA2 * inB2)) >> 32);
156 #endif
157
158 /* Decrement loop counter */
159 colCnt--;
160 }
161
162 /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
163 *px++ = sum1 << 1;
164 *px++ = sum2 << 1;
165 *px2++ = sum3 << 1;
166 *px2++ = sum4 << 1;
167
168 j += 2;
169
170 /* Decrement column loop counter */
171 col--;
172 }
173
174 i = i + (numColsA << 1U);
175 px = px2 + (numColsB & 1U);
176 px2 = px + numColsB;
177
178 /* Decrement row loop counter */
179 row--;
180 }
181
182 /* Compute any remaining odd row/column below */
183
184 /* Compute remaining output column */
185 if (numColsB & 1U) {
186
187 /* Avoid redundant computation of last element */
188 row = numRowsA & (~1U);
189
190 /* Point to remaining unfilled column in output matrix */
191 px = pDst->pData + numColsB-1;
192 pInA = pSrcA->pData;
193
194 /* row loop */
195 while (row > 0)
196 {
197
198 /* point to last column in matrix B */
199 pInB = pSrcB->pData + numColsB-1;
200
201 /* Set variable sum1, that acts as accumulator, to zero */
202 sum1 = 0;
203
204 #if defined (ARM_MATH_LOOPUNROLL)
205
206 /* Loop unrolling: Compute 4 columns at a time. */
207 colCnt = numColsA >> 2U;
208
209 /* matrix multiplication */
210 while (colCnt > 0U)
211 {
212 #if defined (ARM_MATH_DSP)
213 sum1 = __SMMLA(*pInA++, *pInB, sum1);
214 #else
215 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
216 #endif
217 pInB += numColsB;
218
219 #if defined (ARM_MATH_DSP)
220 sum1 = __SMMLA(*pInA++, *pInB, sum1);
221 #else
222 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
223 #endif
224 pInB += numColsB;
225
226 #if defined (ARM_MATH_DSP)
227 sum1 = __SMMLA(*pInA++, *pInB, sum1);
228 #else
229 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
230 #endif
231 pInB += numColsB;
232
233 #if defined (ARM_MATH_DSP)
234 sum1 = __SMMLA(*pInA++, *pInB, sum1);
235 #else
236 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
237 #endif
238 pInB += numColsB;
239
240 /* Decrement loop counter */
241 colCnt--;
242 }
243
244 /* Loop unrolling: Compute remaining column */
245 colCnt = numColsA % 4U;
246
247 #else
248
249 /* Initialize colCnt with number of columns */
250 colCnt = numColsA;
251
252 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
253
254 while (colCnt > 0U) {
255 #if defined (ARM_MATH_DSP)
256 sum1 = __SMMLA(*pInA++, *pInB, sum1);
257 #else
258 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
259 #endif
260 pInB += numColsB;
261
262 colCnt--;
263 }
264
265 /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
266 *px = sum1 << 1;
267 px += numColsB;
268
269 /* Decrement row loop counter */
270 row--;
271 }
272 }
273
274 /* Compute remaining output row */
275 if (numRowsA & 1U) {
276
277 /* point to last row in output matrix */
278 px = pDst->pData + (numColsB) * (numRowsA-1);
279
280 col = numColsB;
281 i = 0U;
282
283 /* col loop */
284 while (col > 0)
285 {
286
287 /* point to last row in matrix A */
288 pInA = pSrcA->pData + (numRowsA-1) * numColsA;
289 pInB = pSrcB->pData + i;
290
291 /* Set variable sum1, that acts as accumulator, to zero */
292 sum1 = 0;
293
294 #if defined (ARM_MATH_LOOPUNROLL)
295
296 /* Loop unrolling: Compute 4 columns at a time. */
297 colCnt = numColsA >> 2U;
298
299 /* matrix multiplication */
300 while (colCnt > 0U)
301 {
302 inA1 = *pInA++;
303 inA2 = *pInA++;
304 inB1 = *pInB;
305 pInB += numColsB;
306 inB2 = *pInB;
307 pInB += numColsB;
308 #if defined (ARM_MATH_DSP)
309 sum1 = __SMMLA(inA1, inB1, sum1);
310 sum1 = __SMMLA(inA2, inB2, sum1);
311 #else
312 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA1 * inB1)) >> 32);
313 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA2 * inB2)) >> 32);
314 #endif
315
316 inA1 = *pInA++;
317 inA2 = *pInA++;
318 inB1 = *pInB;
319 pInB += numColsB;
320 inB2 = *pInB;
321 pInB += numColsB;
322 #if defined (ARM_MATH_DSP)
323 sum1 = __SMMLA(inA1, inB1, sum1);
324 sum1 = __SMMLA(inA2, inB2, sum1);
325 #else
326 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA1 * inB1)) >> 32);
327 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA2 * inB2)) >> 32);
328 #endif
329
330 /* Decrement loop counter */
331 colCnt--;
332 }
333
334 /* Loop unrolling: Compute remaining column */
335 colCnt = numColsA % 4U;
336
337 #else
338
339 /* Initialize colCnt with number of columns */
340 colCnt = numColsA;
341
342 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
343
344 while (colCnt > 0U) {
345 #if defined (ARM_MATH_DSP)
346 sum1 = __SMMLA(*pInA++, *pInB, sum1);
347 #else
348 sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
349 #endif
350 pInB += numColsB;
351
352 colCnt--;
353 }
354
355 /* Saturate and store the result in the destination buffer */
356 *px++ = sum1 << 1;
357 i++;
358
359 /* Decrement col loop counter */
360 col--;
361 }
362 }
363
364 /* Set status as ARM_MATH_SUCCESS */
365 status = ARM_MATH_SUCCESS;
366 }
367
368 /* Return to application */
369 return (status);
370 }
371
372 /**
373 @} end of MatrixMult group
374 */
375