1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_cmplx_mult_f16.c
4  * Description:  Floating-point matrix multiplication
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 
34 /**
35   @ingroup groupMatrix
36  */
37 
38 
39 /**
40   @addtogroup CmplxMatrixMult
41   @{
42  */
43 
44 /**
45   @brief         Floating-point Complex matrix multiplication.
46   @param[in]     pSrcA      points to first input complex matrix structure
47   @param[in]     pSrcB      points to second input complex matrix structure
48   @param[out]    pDst       points to output complex matrix structure
49   @return        execution status
50                    - \ref ARM_MATH_SUCCESS       : Operation successful
51                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
52  */
53 
54 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && defined(__CMSIS_GCC_H)
55 #pragma GCC warning "Scalar version of arm_mat_cmplx_mult_f16 built. Helium version has build issues with gcc."
56 #endif
57 
58 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) &&  !defined(__CMSIS_GCC_H)
59 
60 #include "arm_helium_utils.h"
61 
62 #define DONTCARE            0 /* inactive lane content */
63 
64 
/**
  @brief         Specialized 2x2 complex f16 matrix multiplication (Helium/MVE).
                 Each row of A (2 complex = 4 f16 values) is replicated across
                 both halves of a vector so that both columns of B are processed
                 by a single vector complex multiply.
  @param[in]     pSrcA  points to the first input complex matrix (2x2)
  @param[in]     pSrcB  points to the second input complex matrix (2x2)
  @param[out]    pDst   points to the output complex matrix (2x2)
  @return        ARM_MATH_SUCCESS (dimensions are fixed; no size checking here)
 */
__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_2x2_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    const uint16_t   MATRIX_DIM = 2;
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint16x8_t     vecColBOffs0,vecColAOffs0,vecColAOffs1;
    float16_t       *pInA0 = pInA;
    f16x8_t        acc0, acc1;
    f16x8_t        vecB, vecA0, vecA1;
    f16x8_t        vecTmp;
    uint16_t         tmp;
    /* Gather offsets for B: lanes 0-3 hold column 0 (rows 0 and 1),
       lanes 4-7 hold column 1 (rows 0 and 1). */
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2, 3,
        MATRIX_DIM * CMPLX_DIM + 2 , MATRIX_DIM * CMPLX_DIM + 3,
    };


    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    /* viwdupq wraps at 4: offsets 0,1,2,3,0,1,2,3 -> row 0 of A replicated
       in both vector halves */
    tmp = 0;
    vecColAOffs0 = viwdupq_u16(tmp, 4, 1);

    /* NOTE(review): this store to 'tmp' is unused — the constant is added
       directly on the next line. Kept as-is (behavior unchanged). */
    tmp = (CMPLX_DIM * MATRIX_DIM);
    /* same pattern shifted by one row: row 1 of A replicated */
    vecColAOffs1 = vecColAOffs0 + (uint16_t)(CMPLX_DIM * MATRIX_DIM);


    pInB = (float16_t const *)pSrcB->pData;

    vecA0 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs0);
    vecA1 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs1);


    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    /* complex multiply-accumulate: vcmulq produces the real-product terms,
       vcmlaq_rot90 adds the imaginary-product terms */
    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);


    /*
     * Compute
     *  re0+re1 | im0+im1 | re0+re1 | im0+im1
     *  re2+re3 | im2+im3 | re2+re3 | im2+im3
     */

    vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc0);
    vecTmp = vaddq(vecTmp, acc0);


    /* store one complex result (re,im pair packed in 32 bits) per output slot */
    *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
    *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];

    vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc1);
    vecTmp = vaddq(vecTmp, acc1);

    *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
    *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];

    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
135 
136 
137 
/**
  @brief         Specialized 3x3 complex f16 matrix multiplication (Helium/MVE).
                 One full row of A (3 complex = 6 f16 values) fits in a single
                 predicated vector; each column of B is gathered in turn and the
                 three dot products are reduced per column.
  @param[in]     pSrcA  points to the first input complex matrix (3x3)
  @param[in]     pSrcB  points to the second input complex matrix (3x3)
  @param[out]    pDst   points to the output complex matrix (3x3)
  @return        ARM_MATH_SUCCESS (dimensions are fixed; no size checking here)
 */
__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_3x3_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    const uint16_t   MATRIX_DIM = 3;
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint16x8_t     vecColBOffs0;
    float16_t       *pInA0 = pInA;
    float16_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;   /* row 1 of A */
    float16_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;   /* row 2 of A */
    f16x8_t        acc0, acc1, acc2;
    f16x8_t        vecB, vecA0, vecA1, vecA2;
    /* Gather offsets for one column of B: (row0, row1, row2) complex pairs;
       last two lanes are disabled by predication. */
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
        DONTCARE, DONTCARE
    };


    /* enable predication to disable upper half complex vector element */
    mve_pred16_t p0 = vctp16q(MATRIX_DIM * CMPLX_DIM);

    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    pInB = (float16_t const *)pSrcB->pData;

    /* load the three rows of A once; reused for every B column */
    vecA0 = vldrhq_f16(pInA0);
    vecA1 = vldrhq_f16(pInA1);
    vecA2 = vldrhq_f16(pInA2);

    /* ---- B column 0 ---- */
    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    /* complex multiply: vcmulq gives real-product terms, rot90 adds
       imaginary-product terms */
    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    /* reduce each accumulator to a single complex value and store */
    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- B column 1 ---- */
    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- B column 2 ---- */
    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
230 
231 
232 
233 
/**
  @brief         Specialized 4x4 complex f16 matrix multiplication (Helium/MVE).
                 One full row of A (4 complex = 8 f16 values) exactly fills a
                 vector; the four columns of B are gathered one per pass, with
                 the four row dot products reduced and stored each pass.
  @param[in]     pSrcA  points to the first input complex matrix (4x4)
  @param[in]     pSrcB  points to the second input complex matrix (4x4)
  @param[out]    pDst   points to the output complex matrix (4x4)
  @return        ARM_MATH_SUCCESS (dimensions are fixed; no size checking here)
 */
__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_4x4_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
    const uint16_t   MATRIX_DIM = 4;
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint16x8_t     vecColBOffs0;
    float16_t       *pInA0 = pInA;
    float16_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;   /* row 1 of A */
    float16_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;   /* row 2 of A */
    float16_t       *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM;   /* row 3 of A */
    f16x8_t        acc0, acc1, acc2, acc3;
    f16x8_t        vecB, vecA;
    /* Gather offsets for one column of B: the four (re,im) pairs of
       rows 0..3 in that column. */
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
        3 * MATRIX_DIM * CMPLX_DIM, 3 * MATRIX_DIM * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    pInB = (float16_t const *)pSrcB->pData;

    /* ---- B column 0 ---- */
    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    /* complex multiply: vcmulq gives real-product terms, rot90 adds
       imaginary-product terms */
    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    /* reduce each accumulator to a single complex value and store */
    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- B column 1 ---- */
    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- B column 2 ---- */
    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- B column 3 ---- */
    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
375 
376 
377 
/**
  @brief         Floating-point complex matrix multiplication (Helium/MVE version).
                 Dispatches to fixed-size kernels for square 1x1..4x4 inputs,
                 otherwise processes A rows in blocks of 4 (plus a remainder
                 loop), gathering 4 rows of a B column per vector load.
  @param[in]     pSrcA  points to the first input complex matrix
  @param[in]     pSrcB  points to the second input complex matrix
  @param[out]    pDst   points to the output complex matrix
  @return        ARM_MATH_SUCCESS or ARM_MATH_SIZE_MISMATCH
 */
arm_status arm_mat_cmplx_mult_f16(
  const arm_matrix_instance_f16 * pSrcA,
  const arm_matrix_instance_f16 * pSrcB,
  arm_matrix_instance_f16 * pDst)
{
    float16_t const *pInB = (float16_t const *) pSrcB->pData;   /* input data matrix pointer B */
    float16_t const *pInA = (float16_t const *) pSrcA->pData;   /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;  /* output data matrix pointer */
    float16_t *px;              /* Temporary output data matrix pointer */
    uint16_t  numRowsA = pSrcA->numRows;    /* number of rows of input matrix A    */
    uint16_t  numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
    uint16_t  numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
    uint16_t  col, i = 0U, row = numRowsA;  /* loop counters */
    arm_status status;          /* status of matrix multiplication */
    uint16x8_t vecOffs, vecColBOffs;
    uint32_t  blkCnt,rowCnt;           /* loop counters */

    #ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {

    /*
     * small squared matrix specialized routines
     */
    if (numRowsA == numColsB && numColsB == numColsA)
    {
        if (numRowsA == 1)
        {
            /* 1x1: single scalar complex multiply */
            pOut[0] = pInA[0] * pInB[0] - pInA[1] * pInB[1];
            pOut[1] = pInA[0] * pInB[1] + pInA[1] * pInB[0];
            return (ARM_MATH_SUCCESS);
        }
        else if  (numRowsA == 2)
            return arm_mat_cmplx_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_cmplx_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_cmplx_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
    }

    /* Gather offsets covering the (re,im) pairs of 4 consecutive rows
       within one column of B. */
    vecColBOffs[0] = 0;
    vecColBOffs[1] = 1;
    vecColBOffs[2] = numColsB * CMPLX_DIM;
    vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
    vecColBOffs[4] = 2*numColsB * CMPLX_DIM;
    vecColBOffs[5] = 2*(numColsB * CMPLX_DIM) + 1;
    vecColBOffs[6] = 3*numColsB * CMPLX_DIM;
    vecColBOffs[7] = 3*(numColsB * CMPLX_DIM) + 1;

    /*
     * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
     */

    /*
     * row loop: process 4 rows of A per iteration
     */
    rowCnt = row >> 2;
    while (rowCnt > 0u)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + 4 * numColsB;
        /*
         * For every row wise process, the column loop counter is to be initiated
         */
        col = numColsB;
        /*
         * For every row wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (float16_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * generate 4 columns elements
             */
            /*
             * Matrix A columns number of MAC operations are to be performed
             */

            float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
            float16_t const *pInA0 = pInA;
            float16_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
            float16_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
            float16_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
            f16x8_t acc0, acc1, acc2, acc3;

            acc0 = vdupq_n_f16(0.0f16);
            acc1 = vdupq_n_f16(0.0f16);
            acc2 = vdupq_n_f16(0.0f16);
            acc3 = vdupq_n_f16(0.0f16);

            pSrcA0Vec = (float16_t const *) pInA0;
            pSrcA1Vec = (float16_t const *) pInA1;
            pSrcA2Vec = (float16_t const *) pInA2;
            pSrcA3Vec = (float16_t const *) pInA3;

            vecOffs = vecColBOffs;

            /*
             * process 1 x 4 block output
             * main loop: 8 f16 (= 4 complex) elements per iteration
             */
            blkCnt = (numColsA * CMPLX_DIM) >> 3;
            while (blkCnt > 0U)
            {
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_f16(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs , (uint16_t) (numColsB * 4 * CMPLX_DIM));

                /* vcmlaq accumulates real-product terms, rot90 variant
                   accumulates imaginary-product terms */
                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 8;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 8;
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);

                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 8;
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);

                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 8;
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

                blkCnt--;
            }
            /*
             * Unsupported addressing mode compiler crash
             */
            /*
             * tail
             * (will be merged thru tail predication)
             */
            blkCnt = (numColsA * CMPLX_DIM) & 7;
            if (blkCnt > 0U)
            {
                /* predicate the gather so out-of-range lanes stay inactive */
                mve_pred16_t p0 = vctp16q(blkCnt);
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_z_f16(pInB, vecOffs, p0);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                vecA = vld1q(pSrcA1Vec);
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);

                vecA = vld1q(pSrcA2Vec);
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);

                vecA = vld1q(pSrcA3Vec);
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

            }


            /* reduce each accumulator to a single complex value and store */
            mve_cmplx_sum_intra_vec_f16(acc0, &px[0 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc1, &px[1 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc2, &px[2 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc3, &px[3 * CMPLX_DIM * numColsB + 0]);

            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the  starting address of the next column
             */
            pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the  starting address of the next row
         */
        pInA += (numColsA * 4) * CMPLX_DIM;
        /*
         * Decrement the row loop counter
         */
        rowCnt --;

    }

    /* remainder: process the last (row % 4) rows one at a time */
    rowCnt = row & 3;
    while (rowCnt > 0u)
    {
           /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + numColsB;
        /*
         * For every row wise process, the column loop counter is to be initiated
         */
        col = numColsB;
        /*
         * For every row wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (float16_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * generate 4 columns elements
             */
            /*
             * Matrix A columns number of MAC operations are to be performed
             */

            float16_t const *pSrcA0Vec;
            float16_t const *pInA0 = pInA;
            f16x8_t acc0;

            acc0 = vdupq_n_f16(0.0f16);

            pSrcA0Vec = (float16_t const *) pInA0;

            vecOffs = vecColBOffs;

            /*
             * process 1 x 4 block output
             */
            blkCnt = (numColsA * CMPLX_DIM) >> 3;
            while (blkCnt > 0U)
            {
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (4*numColsB * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);
                pSrcA0Vec += 8;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);


                blkCnt--;
            }


            /*
             * tail
             */
            blkCnt = (numColsA * CMPLX_DIM) & 7;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp16q(blkCnt);
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_z(pInB, vecOffs, p0);

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

            }

            mve_cmplx_sum_intra_vec_f16(acc0, &px[0]);


            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the  starting address of the next column
             */
            pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the  starting address of the next row
         */
        pInA += numColsA  * CMPLX_DIM;
        rowCnt--;
    }

    /*
     * set status as ARM_MATH_SUCCESS
     */
    status = ARM_MATH_SUCCESS;
 }
    /*
     * Return to application
     */
    return (status);
}
702 #else
703 
/**
  @brief         Floating-point complex matrix multiplication (scalar version).
                 Computes pDst = pSrcA * pSrcB where elements are interleaved
                 (re, im) f16 pairs; the inner dot-product loop is manually
                 unrolled by 4 when ARM_MATH_LOOPUNROLL is defined.
  @param[in]     pSrcA  points to the first input complex matrix
  @param[in]     pSrcB  points to the second input complex matrix
  @param[out]    pDst   points to the output complex matrix
  @return        ARM_MATH_SUCCESS or ARM_MATH_SIZE_MISMATCH
 */
arm_status arm_mat_cmplx_mult_f16(
  const arm_matrix_instance_f16 * pSrcA,
  const arm_matrix_instance_f16 * pSrcB,
        arm_matrix_instance_f16 * pDst)
{
  float16_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
  float16_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
  float16_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
  float16_t *pOut = pDst->pData;                 /* Output data matrix pointer */
  float16_t *px;                                 /* Temporary output data matrix pointer */
  uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
  _Float16 sumReal, sumImag;                    /* Accumulator */
  _Float16 a1, b1, c1, d1;
  uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
  arm_status status;                             /* status of matrix multiplication */

#if defined (ARM_MATH_LOOPUNROLL)
  _Float16 a0, b0, c0, d0;
#endif

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + 2 * i;

      /* For every row wise process, the column loop counter is to be initiated */
      col = numColsB;

      /* For every row wise process, the pIn2 pointer is set
       ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sumReal = 0.0f16;
        sumImag = 0.0f16;

        /* Initiate pointer pIn1 to point to starting address of column being processed */
        pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

        /* Apply loop unrolling and compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {

          /* Reading real part of complex matrix A */
          a0 = *pIn1;

          /* Reading real part of complex matrix B */
          c0 = *pIn2;

          /* Reading imaginary part of complex matrix A */
          b0 = *(pIn1 + 1U);

          /* Reading imaginary part of complex matrix B */
          d0 = *(pIn2 + 1U);

          /* Multiply and Accumulate: (a+jb)(c+jd) = (ac-bd) + j(bc+ad) */
          sumReal += a0 * c0;
          sumImag += b0 * c0;

          /* update pointers: A advances one complex element along the row,
             B advances one complex row down the column */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b0 * d0;
          sumImag += a0 * d0;

          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* read real and imag values from pSrcA and pSrcB buffer */
          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          a0 = *(pIn1     );
          c0 = *(pIn2     );
          b0 = *(pIn1 + 1U);
          d0 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a0 * c0;
          sumImag += b0 * c0;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b0 * d0;
          sumImag += a0 * d0;

          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          /* Decrement loop count */
          colCnt--;
        }

        /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
         ** No loop unrolling is used. */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize blkCnt with number of samples */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Store result in destination buffer */
        *px++ = sumReal;
        *px++ = sumImag;

        /* Update pointer pIn2 to point to starting address of next column */
        j++;
        pIn2 = pSrcB->pData + 2U * j;

        /* Decrement column loop counter */
        col--;

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + 2 * numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
924 
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H) */
926 
/**
  @} end of CmplxMatrixMult group
 */
930 
931 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
932 
933