/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_cmplx_mult_f16.c
 * Description:  Floating-point complex matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions_f16.h"

#if defined(ARM_FLOAT16_SUPPORTED)


/**
  @ingroup groupMatrix
 */


/**
  @addtogroup CmplxMatrixMult
  @{
 */

/**
  @brief         Floating-point complex matrix multiplication.
  @param[in]     pSrcA  points to first input complex matrix structure
  @param[in]     pSrcB  points to second input complex matrix structure
  @param[out]    pDst   points to output complex matrix structure
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
 */
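
/*
 * Usage sketch (illustrative only -- the buffers, dimensions and values below
 * are assumptions for the example, not part of the library):
 *
 *   float16_t A[2 * 3 * 2];                // 2x3 complex, interleaved re/im
 *   float16_t B[3 * 4 * 2];                // 3x4 complex
 *   float16_t C[2 * 4 * 2];                // 2x4 complex result
 *
 *   arm_matrix_instance_f16 matA = { 2, 3, A };
 *   arm_matrix_instance_f16 matB = { 3, 4, B };
 *   arm_matrix_instance_f16 matC = { 2, 4, C };
 *
 *   arm_status st = arm_mat_cmplx_mult_f16(&matA, &matB, &matC);
 *   // st == ARM_MATH_SUCCESS on success; with ARM_MATH_MATRIX_CHECK enabled,
 *   // inconsistent dimensions return ARM_MATH_SIZE_MISMATCH.
 */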

#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && defined(__CMSIS_GCC_H)
#pragma GCC warning "Scalar version of arm_mat_cmplx_mult_f16 built. Helium version has build issues with gcc."
#endif

#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H)

#include "arm_helium_utils.h"

#define DONTCARE 0 /* inactive lane content */


__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_2x2_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    const uint16_t MATRIX_DIM = 2;
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;         /* output data matrix pointer */
    uint16x8_t vecColBOffs0, vecColAOffs0, vecColAOffs1;
    float16_t *pInA0 = pInA;
    f16x8_t acc0, acc1;
    f16x8_t vecB, vecA0, vecA1;
    f16x8_t vecTmp;
    uint16_t tmp;
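    /*
     * offsetB0 gathers both columns of the 2x2 matrix B into a single vector:
     * lanes 0-3 hold column 0 (B[0][0], B[1][0]) and lanes 4-7 hold column 1
     * (B[0][1], B[1][1]), each element being an interleaved (re, im) pair.
     */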
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2, 3,
        MATRIX_DIM * CMPLX_DIM + 2, MATRIX_DIM * CMPLX_DIM + 3,
    };


    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

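    /* wrapping increment: viwdupq_u16 with wrap 4 and step 1 yields
       {0, 1, 2, 3, 0, 1, 2, 3}, i.e. row 0 of A (two complex elements)
       replicated into both halves of the vector */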
    tmp = 0;
    vecColAOffs0 = viwdupq_u16(tmp, 4, 1);

    tmp = (CMPLX_DIM * MATRIX_DIM);
    vecColAOffs1 = vecColAOffs0 + (uint16_t)(CMPLX_DIM * MATRIX_DIM);


    pInB = (float16_t const *)pSrcB->pData;

    vecA0 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs0);
    vecA1 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs1);


    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

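    /*
     * complex multiply pattern: vcmulq produces { a.re * b.re, a.re * b.im }
     * per complex lane pair and vcmlaq_rot90 accumulates
     * { -a.im * b.im, a.im * b.re }, so the pair of instructions computes
     * the full complex product a * b.
     */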
    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);


    /*
     * Compute
     *   re0+re1 | im0+im1 | re0+re1 | im0+im1
     *   re2+re3 | im2+im3 | re2+re3 | im2+im3
     */

    vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc0);
    vecTmp = vaddq(vecTmp, acc0);

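    /* each 32-bit lane packs one (re, im) float16 pair: lane 0 holds the
       row's column-0 result, lane 2 its column-1 result */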
    *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
    *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];

    vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc1);
    vecTmp = vaddq(vecTmp, acc1);

    *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
    *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];

    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}



__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_3x3_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    const uint16_t MATRIX_DIM = 3;
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;         /* output data matrix pointer */
    uint16x8_t vecColBOffs0;
    float16_t *pInA0 = pInA;
    float16_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
    float16_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
    f16x8_t acc0, acc1, acc2;
    f16x8_t vecB, vecA0, vecA1, vecA2;
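    /*
     * offsetB0 gathers one full column of the 3x3 matrix B (three complex
     * elements) into lanes 0-5; lanes 6-7 are unused and masked off by the
     * predicate below.
     */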
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
        DONTCARE, DONTCARE
    };


    /* enable predication to disable the unused upper complex vector elements */
    mve_pred16_t p0 = vctp16q(MATRIX_DIM * CMPLX_DIM);

    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    pInB = (float16_t const *)pSrcB->pData;

    vecA0 = vldrhq_f16(pInA0);
    vecA1 = vldrhq_f16(pInA1);
    vecA2 = vldrhq_f16(pInA2);

    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

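    /* reduce each accumulator to one complex value (the sum of its three
       partial complex products) and store it in the current column of the
       corresponding output row */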
    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}




__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_4x4_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    const uint16_t MATRIX_DIM = 4;
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;         /* output data matrix pointer */
    uint16x8_t vecColBOffs0;
    float16_t *pInA0 = pInA;
    float16_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
    float16_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
    float16_t *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM;
    f16x8_t acc0, acc1, acc2, acc3;
    f16x8_t vecB, vecA;
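    /*
     * offsetB0 gathers one full column of the 4x4 matrix B (four complex
     * elements), filling all eight lanes, so no predication is needed.
     */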
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
        3 * MATRIX_DIM * CMPLX_DIM, 3 * MATRIX_DIM * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    pInB = (float16_t const *)pSrcB->pData;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}


arm_status arm_mat_cmplx_mult_f16(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    float16_t const *pInB = (float16_t const *) pSrcB->pData;  /* input data matrix pointer B */
    float16_t const *pInA = (float16_t const *) pSrcA->pData;  /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;       /* output data matrix pointer */
    float16_t *px;                       /* temporary output data matrix pointer */
    uint16_t numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
    uint16_t numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
    uint16_t col, i = 0U, row = numRowsA;  /* loop counters */
    arm_status status;                   /* status of matrix multiplication */
    uint16x8_t vecOffs, vecColBOffs;
    uint32_t blkCnt, rowCnt;             /* loop counters */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {

    /*
     * small square matrix specialized routines
     */
    if (numRowsA == numColsB && numColsB == numColsA)
    {
        if (numRowsA == 1)
        {
            pOut[0] = pInA[0] * pInB[0] - pInA[1] * pInB[1];
            pOut[1] = pInA[0] * pInB[1] + pInA[1] * pInB[0];
            return (ARM_MATH_SUCCESS);
        }
        else if (numRowsA == 2)
            return arm_mat_cmplx_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_cmplx_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_cmplx_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
    }

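    /*
     * vecColBOffs addresses four consecutive rows of one column of B as
     * interleaved (re, im) pairs; numColsB * CMPLX_DIM is the row stride
     * of B in float16 elements.
     */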
    vecColBOffs[0] = 0;
    vecColBOffs[1] = 1;
    vecColBOffs[2] = numColsB * CMPLX_DIM;
    vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
    vecColBOffs[4] = 2 * numColsB * CMPLX_DIM;
    vecColBOffs[5] = 2 * (numColsB * CMPLX_DIM) + 1;
    vecColBOffs[6] = 3 * numColsB * CMPLX_DIM;
    vecColBOffs[7] = 3 * (numColsB * CMPLX_DIM) + 1;

    /*
     * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
     */

    /*
     * row loop
     */
    rowCnt = row >> 2;
    while (rowCnt > 0u)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + 4 * numColsB;
        /*
         * For every row wise process, the column loop counter is to be initiated
         */
        col = numColsB;
        /*
         * For every row wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (float16_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * compute 4 output row elements of the current column
             */
            /*
             * numColsA complex MAC operations are performed per output element
             */

            float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
            float16_t const *pInA0 = pInA;
            float16_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
            float16_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
            float16_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
            f16x8_t acc0, acc1, acc2, acc3;

            acc0 = vdupq_n_f16(0.0f16);
            acc1 = vdupq_n_f16(0.0f16);
            acc2 = vdupq_n_f16(0.0f16);
            acc3 = vdupq_n_f16(0.0f16);

            pSrcA0Vec = (float16_t const *) pInA0;
            pSrcA1Vec = (float16_t const *) pInA1;
            pSrcA2Vec = (float16_t const *) pInA2;
            pSrcA3Vec = (float16_t const *) pInA3;

            vecOffs = vecColBOffs;

            /*
             * process 1 x 4 block output
             */
            blkCnt = (numColsA * CMPLX_DIM) >> 3;
            while (blkCnt > 0U)
            {
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_f16(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 8;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 8;
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);

                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 8;
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);

                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 8;
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

                blkCnt--;
            }
            /*
             * Unsupported addressing mode compiler crash
             */
            /*
             * tail
             * (will be merged thru tail predication)
             */
            blkCnt = (numColsA * CMPLX_DIM) & 7;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp16q(blkCnt);
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_z_f16(pInB, vecOffs, p0);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                vecA = vld1q(pSrcA1Vec);
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);

                vecA = vld1q(pSrcA2Vec);
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);

                vecA = vld1q(pSrcA3Vec);
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

            }


            mve_cmplx_sum_intra_vec_f16(acc0, &px[0 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc1, &px[1 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc2, &px[2 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc3, &px[3 * CMPLX_DIM * numColsB + 0]);

            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += (numColsA * 4) * CMPLX_DIM;
        /*
         * Decrement the row loop counter
         */
        rowCnt--;

    }

    rowCnt = row & 3;
    while (rowCnt > 0u)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + numColsB;
        /*
         * For every row wise process, the column loop counter is to be initiated
         */
        col = numColsB;
        /*
         * For every row wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (float16_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * compute 1 output row element of the current column
             */
            /*
             * numColsA complex MAC operations are performed per output element
             */

            float16_t const *pSrcA0Vec;
            float16_t const *pInA0 = pInA;
            f16x8_t acc0;

            acc0 = vdupq_n_f16(0.0f16);

            pSrcA0Vec = (float16_t const *) pInA0;

            vecOffs = vecColBOffs;

            /*
             * process 1 x 1 block output
             */
            blkCnt = (numColsA * CMPLX_DIM) >> 3;
            while (blkCnt > 0U)
            {
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (4 * numColsB * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);
                pSrcA0Vec += 8;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);


                blkCnt--;
            }


            /*
             * tail
             */
            blkCnt = (numColsA * CMPLX_DIM) & 7;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp16q(blkCnt);
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_z(pInB, vecOffs, p0);

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

            }

            mve_cmplx_sum_intra_vec_f16(acc0, &px[0]);


            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += numColsA * CMPLX_DIM;
        rowCnt--;
    }

    /*
     * set status as ARM_MATH_SUCCESS
     */
    status = ARM_MATH_SUCCESS;
  }
  /*
   * Return to application
   */
  return (status);
}
#else

arm_status arm_mat_cmplx_mult_f16(
  const arm_matrix_instance_f16 * pSrcA,
  const arm_matrix_instance_f16 * pSrcB,
  arm_matrix_instance_f16 * pDst)
{
  float16_t *pIn1 = pSrcA->pData;      /* Input data matrix pointer A */
  float16_t *pIn2 = pSrcB->pData;      /* Input data matrix pointer B */
  float16_t *pInA = pSrcA->pData;      /* Input data matrix pointer A */
  float16_t *pOut = pDst->pData;       /* Output data matrix pointer */
  float16_t *px;                       /* Temporary output data matrix pointer */
  uint16_t numRowsA = pSrcA->numRows;  /* Number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;  /* Number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;  /* Number of columns of input matrix A */
  _Float16 sumReal, sumImag;           /* Accumulators */
  _Float16 a1, b1, c1, d1;
  uint32_t col, i = 0U, j, row = numRowsA, colCnt;  /* loop counters */
  arm_status status;                   /* status of matrix multiplication */

#if defined (ARM_MATH_LOOPUNROLL)
  _Float16 a0, b0, c0, d0;
#endif

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
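    /* Each complex multiply-accumulate below expands
       (a + jb) * (c + jd) = (a*c - b*d) + j(b*c + a*d),
       with (a, b) read from A and (c, d) read from B. */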
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + 2 * i;

      /* For every row wise process, the column loop counter is to be initiated */
      col = numColsB;

      /* For every row wise process, the pIn2 pointer is set
       ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* column loop */
      do
      {
        /* Set the accumulators to zero */
        sumReal = 0.0f16;
        sumImag = 0.0f16;

        /* Initiate pointer pIn1 to point to starting address of column being processed */
        pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

        /* Apply loop unrolling and compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {

          /* Reading real part of complex matrix A */
          a0 = *pIn1;

          /* Reading real part of complex matrix B */
          c0 = *pIn2;

          /* Reading imaginary part of complex matrix A */
          b0 = *(pIn1 + 1U);

          /* Reading imaginary part of complex matrix B */
          d0 = *(pIn2 + 1U);

          /* Multiply and accumulate */
          sumReal += a0 * c0;
          sumImag += b0 * c0;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and accumulate */
          sumReal -= b0 * d0;
          sumImag += a0 * d0;

          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* read real and imag values from pSrcA and pSrcB buffer */
          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          a0 = *(pIn1     );
          c0 = *(pIn2     );
          b0 = *(pIn1 + 1U);
          d0 = *(pIn2 + 1U);

          /* Multiply and accumulate */
          sumReal += a0 * c0;
          sumImag += b0 * c0;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and accumulate */
          sumReal -= b0 * d0;
          sumImag += a0 * d0;

          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          /* Decrement loop count */
          colCnt--;
        }

        /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
        ** No loop unrolling is used. */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize colCnt with number of columns */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Store result in destination buffer */
        *px++ = sumReal;
        *px++ = sumImag;

        /* Update pointer pIn2 to point to starting address of next column */
        j++;
        pIn2 = pSrcB->pData + 2U * j;

        /* Decrement column loop counter */
        col--;

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + 2 * numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}

#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H) */

/**
  @} end of CmplxMatrixMult group
 */

#endif /* #if defined(ARM_FLOAT16_SUPPORTED) */