/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_cmplx_mult_q31.c
 * Description:  Q31 complex matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

/**
  @ingroup groupMatrix
 */

/**
  @addtogroup CmplxMatrixMult
  @{
 */

/**
  @brief         Q31 Complex matrix multiplication.
  @param[in]     pSrcA      points to first input complex matrix structure
  @param[in]     pSrcB      points to second input complex matrix structure
  @param[out]    pDst       points to output complex matrix structure
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed

  @par           Scaling and Overflow Behavior
                 The function is implemented using an internal 64-bit accumulator.
                 The accumulator has a 2.62 format and maintains full precision of the
                 intermediate multiplication results but provides only a single guard bit.
                 There is no saturation on intermediate additions, so if the accumulator
                 overflows it wraps around and distorts the result. The input signals
                 should therefore be scaled down by log2(numColsA) bits to avoid
                 intermediate overflows, as a total of numColsA additions are performed
                 internally. The 2.62 accumulator is right shifted by 31 bits and
                 saturated to 1.31 format to yield the final result.
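
  @par           Usage example
                 A minimal, illustrative sketch of calling this function on two
                 2x2 complex matrices; the buffer names and input values below are
                 assumptions for illustration, not taken from the library:
  <pre>
      // interleaved {real, imag} storage, values pre-scaled to +/- 0.25
      q31_t aData[2 * 2 * 2] = { 0x20000000, 0, 0, 0x20000000,
                                 0, 0x20000000, 0x20000000, 0 };
      q31_t bData[2 * 2 * 2] = { 0x20000000, 0, 0, 0,
                                 0, 0, 0x20000000, 0 };
      q31_t dData[2 * 2 * 2];
      arm_matrix_instance_q31 A, B, D;

      arm_mat_init_q31(&A, 2, 2, aData);
      arm_mat_init_q31(&B, 2, 2, bData);
      arm_mat_init_q31(&D, 2, 2, dData);

      if (arm_mat_cmplx_mult_q31(&A, &B, &D) != ARM_MATH_SUCCESS)
      {
          // matrix dimensions did not match
      }
  </pre>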
 */
#if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)

#include "arm_helium_utils.h"

#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4

__STATIC_INLINE arm_status arm_mat_cmplx_mult_q31_2x2_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    q31_t const *pInA = pSrcA->pData;  /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;   /* output data matrix pointer  */
    uint32x4_t   vecColBOffs0;
    q31_t const *pInA0 = pInA;
    q31_t const *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM2;
    q63_t        acc0, acc1, acc2, acc3;
    q31x4_t      vecB, vecA;

    static const uint32_t offsetB0[4] = {
        0, 1,
        MATRIX_DIM2 * CMPLX_DIM, MATRIX_DIM2 * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrwq_u32(offsetB0);

    pInB = (q31_t const *) pSrcB->pData;

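    /*
     * Each gather fetches the {re, im} pairs of two consecutive B rows of the
     * current column. vmlsldavq then forms the real part of the complex dot
     * product, sum(a_re * b_re - a_im * b_im), while vmlaldavxq forms the
     * imaginary part, sum(a_re * b_im + a_im * b_re).
     */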
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 1] = (q31_t) asrl(acc3, 31);
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    pOut += CMPLX_DIM;

    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 1] = (q31_t) asrl(acc3, 31);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}

__STATIC_INLINE arm_status arm_mat_cmplx_mult_q31_3x3_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    q31_t const *pInA = pSrcA->pData;  /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;   /* output data matrix pointer  */
    uint32x4_t   vecColBOffs0, vecColBOffs1;
    q31_t const *pInA0 = pInA;
    q31_t const *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM3;
    q31_t const *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM3;
    q63_t        acc0, acc1, acc2, acc3;
    q31x4_t      vecB, vecB1, vecA;
    /*
     * enable predication to disable the upper (second) complex element of the vector
     */
    mve_pred16_t p0 = vctp32q(CMPLX_DIM);

    static const uint32_t offsetB0[4] = {
        0, 1,
        MATRIX_DIM3 * CMPLX_DIM, MATRIX_DIM3 * CMPLX_DIM + 1
    };
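    /*
     * A 3x3 matrix provides only three B rows per column: the second gather
     * covers row 2 only, and its two remaining lanes are parked on
     * INACTIVELANE; the predicated A loads below zero the matching lanes of
     * vecA, so those gather lanes never contribute to the accumulators.
     */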
    static const uint32_t offsetB1[4] = {
        2 * MATRIX_DIM3 * CMPLX_DIM, 2 * MATRIX_DIM3 * CMPLX_DIM + 1,
        INACTIVELANE, INACTIVELANE
    };

    vecColBOffs0 = vldrwq_u32(offsetB0);
    vecColBOffs1 = vldrwq_u32(offsetB1);

    pInB = (q31_t const *) pSrcB->pData;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_z_s32(&pInA0[4], p0);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_z_s32(&pInA1[4], p0);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_z_s32(&pInA2[4], p0);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc1, 31);
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_z_s32(&pInA0[4], p0);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_z_s32(&pInA1[4], p0);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_z_s32(&pInA2[4], p0);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc1, 31);
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_z_s32(&pInA0[4], p0);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_z_s32(&pInA1[4], p0);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_z_s32(&pInA2[4], p0);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = (q31_t) asrl(acc1, 31);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}

__STATIC_INLINE arm_status arm_mat_cmplx_mult_q31_4x4_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    q31_t const *pInA = pSrcA->pData;  /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;   /* output data matrix pointer  */
    uint32x4_t   vecColBOffs0, vecColBOffs1;
    q31_t const *pInA0 = pInA;
    q31_t const *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM4;
    q31_t const *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM4;
    q31_t const *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM4;
    q63_t        acc0, acc1, acc2, acc3;
    q31x4_t      vecB, vecB1, vecA;

    static const uint32_t offsetB0[4] = {
        0, 1,
        MATRIX_DIM4 * CMPLX_DIM, MATRIX_DIM4 * CMPLX_DIM + 1
    };
    static const uint32_t offsetB1[4] = {
        2 * MATRIX_DIM4 * CMPLX_DIM, 2 * MATRIX_DIM4 * CMPLX_DIM + 1,
        3 * MATRIX_DIM4 * CMPLX_DIM, 3 * MATRIX_DIM4 * CMPLX_DIM + 1
    };
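    /*
     * With four B rows per column, both gathers are fully populated, so the
     * 4x4 kernel needs no predication: plain vldrwq loads cover the second
     * half of each A row.
     */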

    vecColBOffs0 = vldrwq_u32(offsetB0);
    vecColBOffs1 = vldrwq_u32(offsetB1);

    pInB = (q31_t const *) pSrcB->pData;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA0[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA1[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA3);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA2[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA3[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA0[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA1[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA3);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA2[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA3[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA0[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA1[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA3);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA2[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA3[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);
    vecB1 = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_s32(pInA0);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA1);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA0[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA1[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);

    vecA = vldrwq_s32(pInA2);
    acc0 = vmlsldavq_s32(vecA, vecB);
    acc1 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(pInA3);
    acc2 = vmlsldavq_s32(vecA, vecB);
    acc3 = vmlaldavxq_s32(vecA, vecB);

    vecA = vldrwq_s32(&pInA2[4]);
    acc0 = vmlsldavaq_s32(acc0, vecA, vecB1);
    acc1 = vmlaldavaxq_s32(acc1, vecA, vecB1);

    vecA = vldrwq_s32(&pInA3[4]);
    acc2 = vmlsldavaq_s32(acc2, vecA, vecB1);
    acc3 = vmlaldavaxq_s32(acc3, vecA, vecB1);

    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc0, 31);
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc1, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = (q31_t) asrl(acc2, 31);
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = (q31_t) asrl(acc3, 31);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}


arm_status arm_mat_cmplx_mult_q31(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t const *pInB = (q31_t const *) pSrcB->pData; /* input data matrix pointer B */
    q31_t const *pInA = (q31_t const *) pSrcA->pData; /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;                  /* output data matrix pointer  */
    q31_t       *px;                                  /* temporary output data matrix pointer */
    uint16_t     numRowsA = pSrcA->numRows;           /* number of rows of input matrix A    */
    uint16_t     numColsB = pSrcB->numCols;           /* number of columns of input matrix B */
    uint16_t     numColsA = pSrcA->numCols;           /* number of columns of input matrix A */
    uint16_t     col, i = 0U, row = numRowsA;         /* loop counters */
    arm_status   status;                              /* status of matrix multiplication */
    uint32x4_t   vecOffs, vecColBOffs;
    uint32_t     blkCnt, rowCnt;                      /* loop counters */

#ifdef ARM_MATH_MATRIX_CHECK

    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

    {
        /*
         * small square matrix specialized routines
         */
        if (numRowsA == numColsB && numColsB == numColsA)
        {
            if (numRowsA == 1)
            {
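                /*
                 * 1x1 case: a single complex multiply,
                 * (a + jb)(c + jd) = (ac - bd) + j(ad + bc)
                 */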
                q63_t sumReal = (q63_t) pInA[0] * pInB[0];
                sumReal -= (q63_t) pInA[1] * pInB[1];

                q63_t sumImag = (q63_t) pInA[0] * pInB[1];
                sumImag += (q63_t) pInA[1] * pInB[0];

                /* Store result in destination buffer */
                pOut[0] = (q31_t) clip_q63_to_q31(sumReal >> 31);
                pOut[1] = (q31_t) clip_q63_to_q31(sumImag >> 31);
                return (ARM_MATH_SUCCESS);
            }
            else if (numRowsA == 2)
                return arm_mat_cmplx_mult_q31_2x2_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 3)
                return arm_mat_cmplx_mult_q31_3x3_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 4)
                return arm_mat_cmplx_mult_q31_4x4_mve(pSrcA, pSrcB, pDst);
        }

        vecColBOffs[0] = 0;
        vecColBOffs[1] = 1;
        vecColBOffs[2] = numColsB * CMPLX_DIM;
        vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
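        /*
         * vecColBOffs gathers the {re, im} pair of the current B row and of
         * the row immediately below it, so each inner-loop iteration consumes
         * two B rows of the active column.
         */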

        /*
         * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
         */

        /*
         * row loop
         */
        rowCnt = row >> 1;
        while (rowCnt > 0u)
        {
            /*
             * Output pointer is set to starting address of the row being processed
             */
            px = pOut + i * CMPLX_DIM;
            i = i + 2 * numColsB;
            /*
             * For every row-wise process, the column loop counter must be initialized
             */
            col = numColsB;
            /*
             * For every row-wise process, the pInB pointer is set
             * to the starting address of the pSrcB data
             */
            pInB = (q31_t const *) pSrcB->pData;
            /*
             * column loop
             */
            while (col > 0u)
            {
                /*
                 * Compute one complex output element for each of two A rows;
                 * numColsA complex MACs are performed per output element
                 */
                q31_t const *pSrcA0Vec, *pSrcA1Vec;
                q31_t const *pInA0 = pInA;
                q31_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
                q63_t        acc0, acc1, acc2, acc3;

                acc0 = 0LL;
                acc1 = 0LL;
                acc2 = 0LL;
                acc3 = 0LL;

                pSrcA0Vec = (q31_t const *) pInA0;
                pSrcA1Vec = (q31_t const *) pInA1;

                vecOffs = vecColBOffs;

                /*
                 * process a 2-row x 1-column block of outputs
                 */
                blkCnt = (numColsA * CMPLX_DIM) >> 2;
                while (blkCnt > 0U)
                {
                    q31x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                    /*
                     * move Matrix B read offsets, 2 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    vecA = vld1q(pSrcA0Vec);
                    pSrcA0Vec += 4;
                    acc0 = vmlsldavaq(acc0, vecA, vecB);
                    acc1 = vmlaldavaxq(acc1, vecA, vecB);

                    vecA = vld1q(pSrcA1Vec);
                    pSrcA1Vec += 4;
                    acc2 = vmlsldavaq(acc2, vecA, vecB);
                    acc3 = vmlaldavaxq(acc3, vecA, vecB);

                    blkCnt--;
                }

                /*
                 * tail
                 */
                blkCnt = (numColsA * CMPLX_DIM) & 3;
                if (blkCnt > 0U)
                {
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    q31x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);

                    /*
                     * move Matrix B read offsets, 2 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    vecA = vld1q(pSrcA0Vec);
                    acc0 = vmlsldavaq(acc0, vecA, vecB);
                    acc1 = vmlaldavaxq(acc1, vecA, vecB);
                    vecA = vld1q(pSrcA1Vec);
                    acc2 = vmlsldavaq(acc2, vecA, vecB);
                    acc3 = vmlaldavaxq(acc3, vecA, vecB);
                }

                px[0 * CMPLX_DIM * numColsB + 0] = (q31_t) clip_q63_to_q31(acc0 >> 31);
                px[0 * CMPLX_DIM * numColsB + 1] = (q31_t) clip_q63_to_q31(acc1 >> 31);
                px[1 * CMPLX_DIM * numColsB + 0] = (q31_t) clip_q63_to_q31(acc2 >> 31);
                px[1 * CMPLX_DIM * numColsB + 1] = (q31_t) clip_q63_to_q31(acc3 >> 31);
                px += CMPLX_DIM;
                /*
                 * Decrement the column loop counter
                 */
                col--;
                /*
                 * Update the pointer pInB to point to the starting address of the next column
                 */
                pInB = (q31_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
            }

            /*
             * Update the pointer pInA to point to the starting address of the next pair of rows
             */
            pInA += (numColsA * 2) * CMPLX_DIM;
            /*
             * Decrement the row loop counter
             */
            rowCnt--;
        }

        rowCnt = row & 1;
        while (rowCnt > 0u)
        {
            /*
             * Output pointer is set to starting address of the row being processed
             */
            px = pOut + i * CMPLX_DIM;
            i = i + numColsB;
            /*
             * For every row-wise process, the column loop counter must be initialized
             */
            col = numColsB;
            /*
             * For every row-wise process, the pInB pointer is set
             * to the starting address of the pSrcB data
             */
            pInB = (q31_t const *) pSrcB->pData;
            /*
             * column loop
             */
            while (col > 0u)
            {
                /*
                 * Compute one complex output element for the last A row;
                 * numColsA complex MACs are performed per output element
                 */
                q31_t const *pSrcA0Vec;
                q31_t const *pInA0 = pInA;
                q63_t        acc0, acc1;

                acc0 = 0LL;
                acc1 = 0LL;

                pSrcA0Vec = (q31_t const *) pInA0;

                vecOffs = vecColBOffs;

                /*
                 * process a 1-row x 1-column block of outputs
                 */
                blkCnt = (numColsA * CMPLX_DIM) >> 2;
                while (blkCnt > 0U)
                {
                    q31x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                    /*
                     * move Matrix B read offsets, 2 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    vecA = vld1q(pSrcA0Vec);
                    pSrcA0Vec += 4;
                    acc0 = vmlsldavaq(acc0, vecA, vecB);
                    acc1 = vmlaldavaxq(acc1, vecA, vecB);

                    blkCnt--;
                }

                /*
                 * tail
                 */
                blkCnt = (numColsA * CMPLX_DIM) & 3;
                if (blkCnt > 0U)
                {
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    q31x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);

                    /*
                     * move Matrix B read offsets, 2 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    vecA = vld1q(pSrcA0Vec);
                    acc0 = vmlsldavaq(acc0, vecA, vecB);
                    acc1 = vmlaldavaxq(acc1, vecA, vecB);
                }

                px[0] = (q31_t) clip_q63_to_q31(acc0 >> 31);
                px[1] = (q31_t) clip_q63_to_q31(acc1 >> 31);

                px += CMPLX_DIM;
                /*
                 * Decrement the column loop counter
                 */
                col--;
                /*
                 * Update the pointer pInB to point to the starting address of the next column
                 */
                pInB = (q31_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
            }

            /*
             * Update the pointer pInA to point to the starting address of the next row
             */
            pInA += numColsA * CMPLX_DIM;
            rowCnt--;
        }

        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}

#else
arm_status arm_mat_cmplx_mult_q31(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t *pIn1 = pSrcA->pData;                      /* Input data matrix pointer A */
    q31_t *pIn2 = pSrcB->pData;                      /* Input data matrix pointer B */
    q31_t *pInA = pSrcA->pData;                      /* Input data matrix pointer A */
    q31_t *pOut = pDst->pData;                       /* Output data matrix pointer  */
    q31_t *px;                                       /* Temporary output data matrix pointer */
    uint16_t numRowsA = pSrcA->numRows;              /* Number of rows of input matrix A    */
    uint16_t numColsB = pSrcB->numCols;              /* Number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols;              /* Number of columns of input matrix A */
    q63_t sumReal, sumImag;                          /* Accumulators */
    q31_t a1, b1, c1, d1;
    uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* Loop counters */
    arm_status status;                               /* Status of matrix multiplication */

#if defined (ARM_MATH_LOOPUNROLL)
    q31_t a0, b0, c0, d0;
#endif

#ifdef ARM_MATH_MATRIX_CHECK

    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) ||
        (pSrcB->numCols != pDst->numCols)   )
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

    {
        /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
        /* row loop */
        do
        {
            /* Output pointer is set to starting address of the row being processed */
            px = pOut + 2 * i;

            /* For every row-wise process, the column loop counter must be initialized */
            col = numColsB;

            /* For every row-wise process, the pIn2 pointer is set
            ** to the starting address of the pSrcB data */
            pIn2 = pSrcB->pData;

            j = 0U;

            /* column loop */
            do
            {
                /* Set the accumulator variables to zero */
                sumReal = 0;
                sumImag = 0;

                /* Initialize pIn1 to the starting address of the A row being processed */
                pIn1 = pInA;
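                /* The loops below accumulate sum(a*c - b*d) into sumReal and
                   sum(b*c + a*d) into sumImag, i.e. the complex dot product of
                   the current A row and B column. */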

#if defined (ARM_MATH_LOOPUNROLL)

                /* Apply loop unrolling and compute 4 MACs simultaneously. */
                colCnt = numColsA >> 2U;

                /* matrix multiplication */
                while (colCnt > 0U)
                {
                    /* Reading real part of complex matrix A */
                    a0 = *pIn1;

                    /* Reading real part of complex matrix B */
                    c0 = *pIn2;

                    /* Reading imaginary part of complex matrix A */
                    b0 = *(pIn1 + 1U);

                    /* Reading imaginary part of complex matrix B */
                    d0 = *(pIn2 + 1U);

                    /* Multiply and Accumulate */
                    sumReal += (q63_t) a0 * c0;
                    sumImag += (q63_t) b0 * c0;

                    /* update pointers */
                    pIn1 += 2U;
                    pIn2 += 2 * numColsB;

                    /* Multiply and Accumulate */
                    sumReal -= (q63_t) b0 * d0;
                    sumImag += (q63_t) a0 * d0;

                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* read real and imag values from pSrcA and pSrcB buffer */
                    a1 = *(pIn1     );
                    c1 = *(pIn2     );
                    b1 = *(pIn1 + 1U);
                    d1 = *(pIn2 + 1U);

                    /* Multiply and Accumulate */
                    sumReal += (q63_t) a1 * c1;
                    sumImag += (q63_t) b1 * c1;

                    /* update pointers */
                    pIn1 += 2U;
                    pIn2 += 2 * numColsB;

                    /* Multiply and Accumulate */
                    sumReal -= (q63_t) b1 * d1;
                    sumImag += (q63_t) a1 * d1;

                    a0 = *(pIn1     );
                    c0 = *(pIn2     );
                    b0 = *(pIn1 + 1U);
                    d0 = *(pIn2 + 1U);

                    /* Multiply and Accumulate */
                    sumReal += (q63_t) a0 * c0;
                    sumImag += (q63_t) b0 * c0;

                    /* update pointers */
                    pIn1 += 2U;
                    pIn2 += 2 * numColsB;

                    /* Multiply and Accumulate */
                    sumReal -= (q63_t) b0 * d0;
                    sumImag += (q63_t) a0 * d0;

                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    a1 = *(pIn1     );
                    c1 = *(pIn2     );
                    b1 = *(pIn1 + 1U);
                    d1 = *(pIn2 + 1U);

                    /* Multiply and Accumulate */
                    sumReal += (q63_t) a1 * c1;
                    sumImag += (q63_t) b1 * c1;

                    /* update pointers */
                    pIn1 += 2U;
                    pIn2 += 2 * numColsB;

                    /* Multiply and Accumulate */
                    sumReal -= (q63_t) b1 * d1;
                    sumImag += (q63_t) a1 * d1;

                    /* Decrement loop count */
                    colCnt--;
                }

                /* If the number of columns of pSrcA is not a multiple of 4,
                ** compute any remaining MACs here. No loop unrolling is used. */
                colCnt = numColsA % 0x4U;

#else

                /* Initialize colCnt with the number of columns */
                colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
                    a1 = *(pIn1     );
                    c1 = *(pIn2     );
                    b1 = *(pIn1 + 1U);
                    d1 = *(pIn2 + 1U);

                    /* Multiply and Accumulate */
                    sumReal += (q63_t) a1 * c1;
                    sumImag += (q63_t) b1 * c1;

                    /* update pointers */
                    pIn1 += 2U;
                    pIn2 += 2 * numColsB;

                    /* Multiply and Accumulate */
                    sumReal -= (q63_t) b1 * d1;
                    sumImag += (q63_t) a1 * d1;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Store result in destination buffer */
                *px++ = (q31_t) clip_q63_to_q31(sumReal >> 31);
                *px++ = (q31_t) clip_q63_to_q31(sumImag >> 31);

                /* Update pointer pIn2 to point to starting address of next column */
                j++;
                pIn2 = pSrcB->pData + 2U * j;

                /* Decrement column loop counter */
                col--;

            } while (col > 0U);

            /* Update pointer pInA to point to starting address of next row */
            i = i + numColsB;
            pInA = pInA + 2 * numColsA;

            /* Decrement row loop counter */
            row--;

        } while (row > 0U);

        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}
#endif /* defined(ARM_MATH_MVEI) */

/**
  @} end of CmplxMatrixMult group
 */