• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_ldl_f32.c
4  * Description:  Floating-point LDL decomposition
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 
32 
33 
34 
35 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
36 
37 
/// @private
/// Swap rows i and j of the n x n row-major matrix A in place, four float32
/// elements per iteration; the final partial group is masked with the tail
/// predicate produced by vctp32q. Must be expanded where an int `n` (the matrix
/// dimension) is visible. Arguments are parenthesized so expressions such as
/// `k + 1` expand correctly; i and j are re-evaluated on each iteration, so
/// pass side-effect-free expressions. The do/while(0) wrapper makes the macro
/// behave as a single statement.
#define SWAP_ROWS_F32(A,i,j)                                  \
  do {                                                        \
    int cnt_ = n;                                             \
    for (int w_ = 0; w_ < n; w_ += 4)                         \
    {                                                         \
      mve_pred16_t p0_ = vctp32q(cnt_);                       \
      f32x4_t tmpa_ = vldrwq_z_f32(&(A)[(i)*n + w_], p0_);    \
      f32x4_t tmpb_ = vldrwq_z_f32(&(A)[(j)*n + w_], p0_);    \
                                                              \
      vstrwq_p(&(A)[(i)*n + w_], tmpb_, p0_);                 \
      vstrwq_p(&(A)[(j)*n + w_], tmpa_, p0_);                 \
                                                              \
      cnt_ -= 4;                                              \
    }                                                         \
  } while (0)
57 
/// @private
/// Swap columns i and j of the n x n row-major matrix A in place.
/// Must be expanded where an int `n` (the matrix dimension) is visible.
/// Arguments are parenthesized so expressions such as `k + 1` expand
/// correctly; i and j are re-evaluated on each iteration, so pass
/// side-effect-free expressions. The do/while(0) wrapper makes the macro
/// behave as a single statement (safe inside an unbraced if/else).
#define SWAP_COLS_F32(A,i,j)                   \
  do {                                         \
    for (int w_ = 0; w_ < n; w_++)             \
    {                                          \
      float32_t tmp_ = (A)[w_*n + (i)];        \
      (A)[w_*n + (i)] = (A)[w_*n + (j)];       \
      (A)[w_*n + (j)] = tmp_;                  \
    }                                          \
  } while (0)
67 
68 /**
69   @ingroup groupMatrix
70  */
71 
72 /**
73   @addtogroup MatrixChol
74   @{
75  */
76 
77 /**
78    * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
79    * @param[in]  pSrc   points to the instance of the input floating-point matrix structure.
80    * @param[out] pl   points to the instance of the output floating-point triangular matrix structure.
81    * @param[out] pd   points to the instance of the output floating-point diagonal matrix structure.
82    * @param[out] pp   points to the instance of the output floating-point permutation vector.
83    * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
84    * @return        execution status
85                    - \ref ARM_MATH_SUCCESS       : Operation successful
86                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
87                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
88    * @par
89    *  Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
90    */
arm_mat_ldlt_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pl,arm_matrix_instance_f32 * pd,uint16_t * pp)91 arm_status arm_mat_ldlt_f32(
92   const arm_matrix_instance_f32 * pSrc,
93   arm_matrix_instance_f32 * pl,
94   arm_matrix_instance_f32 * pd,
95   uint16_t * pp)
96 {
97 
98   arm_status status;                             /* status of matrix inverse */
99 
100 
101 #ifdef ARM_MATH_MATRIX_CHECK
102 
103   /* Check for matrix mismatch condition */
104   if ((pSrc->numRows != pSrc->numCols) ||
105       (pl->numRows != pl->numCols) ||
106       (pd->numRows != pd->numCols) ||
107       (pl->numRows != pd->numRows)   )
108   {
109     /* Set status as ARM_MATH_SIZE_MISMATCH */
110     status = ARM_MATH_SIZE_MISMATCH;
111   }
112   else
113 
114 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
115 
116   {
117 
118     const int n=pSrc->numRows;
119     int fullRank = 1, diag,k;
120     float32_t *pA;
121 
122     memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
123     pA = pl->pData;
124 
125     int cnt = n;
126     uint16x8_t vecP;
127 
128     for(int k=0;k < n; k+=8)
129     {
130       mve_pred16_t p0;
131       p0 = vctp16q(cnt);
132 
133       vecP = vidupq_u16((uint16_t)k, 1);
134 
135       vstrhq_p(&pp[k], vecP, p0);
136 
137       cnt -= 8;
138     }
139 
140 
141     for(k=0;k < n; k++)
142     {
143         /* Find pivot */
144         float32_t m=F32_MIN,a;
145         int j=k;
146 
147 
148         for(int r=k;r<n;r++)
149         {
150            if (pA[r*n+r] > m)
151            {
152              m = pA[r*n+r];
153              j = r;
154            }
155         }
156 
157         if(j != k)
158         {
159           SWAP_ROWS_F32(pA,k,j);
160           SWAP_COLS_F32(pA,k,j);
161         }
162 
163 
164         pp[k] = j;
165 
166         a = pA[k*n+k];
167 
168         if (fabs(a) < 1.0e-8)
169         {
170 
171             fullRank = 0;
172             break;
173         }
174 
175         float32_t invA;
176 
177         invA = 1.0f / a;
178 
179         int32x4_t vecOffs;
180         int w;
181         vecOffs = vidupq_u32((uint32_t)0, 1);
182         vecOffs = vmulq_n_s32(vecOffs,n);
183 
184         for(w=k+1; w<n; w+=4)
185         {
186           int cnt = n - k - 1;
187 
188           f32x4_t vecX;
189 
190           f32x4_t vecA;
191           f32x4_t vecW0,vecW1, vecW2, vecW3;
192 
193           mve_pred16_t p0;
194 
195           vecW0 = vdupq_n_f32(pA[(w + 0)*n+k]);
196           vecW1 = vdupq_n_f32(pA[(w + 1)*n+k]);
197           vecW2 = vdupq_n_f32(pA[(w + 2)*n+k]);
198           vecW3 = vdupq_n_f32(pA[(w + 3)*n+k]);
199 
200           for(int x=k+1;x<n;x += 4)
201           {
202              p0 = vctp32q(cnt);
203 
204              //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
205 
206 
207              vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], vecOffs, p0);
208              vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
209 
210 
211              vecA = vldrwq_z_f32(&pA[(w + 0)*n+x],p0);
212              vecA = vfmsq_m(vecA, vecW0, vecX, p0);
213              vstrwq_p(&pA[(w + 0)*n+x], vecA, p0);
214 
215              vecA = vldrwq_z_f32(&pA[(w + 1)*n+x],p0);
216              vecA = vfmsq_m(vecA, vecW1, vecX, p0);
217              vstrwq_p(&pA[(w + 1)*n+x], vecA, p0);
218 
219              vecA = vldrwq_z_f32(&pA[(w + 2)*n+x],p0);
220              vecA = vfmsq_m(vecA, vecW2, vecX, p0);
221              vstrwq_p(&pA[(w + 2)*n+x], vecA, p0);
222 
223              vecA = vldrwq_z_f32(&pA[(w + 3)*n+x],p0);
224              vecA = vfmsq_m(vecA, vecW3, vecX, p0);
225              vstrwq_p(&pA[(w + 3)*n+x], vecA, p0);
226 
227              cnt -= 4;
228           }
229         }
230 
231         for(; w<n; w++)
232         {
233           int cnt = n - k - 1;
234 
235           f32x4_t vecA,vecX,vecW;
236 
237 
238           mve_pred16_t p0;
239 
240           vecW = vdupq_n_f32(pA[w*n+k]);
241 
242           for(int x=k+1;x<n;x += 4)
243           {
244              p0 = vctp32q(cnt);
245 
246              //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
247 
248              vecA = vldrwq_z_f32(&pA[w*n+x],p0);
249 
250              vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], vecOffs, p0);
251              vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
252 
253              vecA = vfmsq_m(vecA, vecW, vecX, p0);
254 
255              vstrwq_p(&pA[w*n+x], vecA, p0);
256 
257              cnt -= 4;
258           }
259         }
260 
261         for(int w=k+1;w<n;w++)
262         {
263                pA[w*n+k] = pA[w*n+k] * invA;
264         }
265 
266 
267 
268     }
269 
270 
271 
272     diag=k;
273     if (!fullRank)
274     {
275       diag--;
276       for(int row=0; row < n;row++)
277       {
278         mve_pred16_t p0;
279         int cnt= n-k;
280         f32x4_t zero=vdupq_n_f32(0.0f);
281 
282         for(int col=k; col < n;col += 4)
283         {
284            p0 = vctp32q(cnt);
285 
286            vstrwq_p(&pl->pData[row*n+col], zero, p0);
287 
288            cnt -= 4;
289         }
290       }
291     }
292 
293     for(int row=0; row < n;row++)
294     {
295        mve_pred16_t p0;
296        int cnt= n-row-1;
297        f32x4_t zero=vdupq_n_f32(0.0f);
298 
299        for(int col=row+1; col < n;col+=4)
300        {
301          p0 = vctp32q(cnt);
302 
303          vstrwq_p(&pl->pData[row*n+col], zero, p0);
304 
305          cnt -= 4;
306        }
307     }
308 
309     for(int d=0; d < diag;d++)
310     {
311       pd->pData[d*n+d] = pl->pData[d*n+d];
312       pl->pData[d*n+d] = 1.0;
313     }
314 
315     status = ARM_MATH_SUCCESS;
316 
317   }
318 
319 
320   /* Return to application */
321   return (status);
322 }
323 #else
324 
/// @private
/// Swap rows i and j of the n x n row-major matrix A in place.
/// Must be expanded where an int `n` (the matrix dimension) is visible.
/// Arguments are parenthesized so expressions such as `k + 1` expand
/// correctly; i and j are re-evaluated on each iteration, so pass
/// side-effect-free expressions. The do/while(0) wrapper makes the macro
/// behave as a single statement (safe inside an unbraced if/else).
#define SWAP_ROWS_F32(A,i,j)                   \
  do {                                         \
    for (int w_ = 0; w_ < n; w_++)             \
    {                                          \
      float32_t tmp_ = (A)[(i)*n + w_];        \
      (A)[(i)*n + w_] = (A)[(j)*n + w_];       \
      (A)[(j)*n + w_] = tmp_;                  \
    }                                          \
  } while (0)
334 
/// @private
/// Swap columns i and j of the n x n row-major matrix A in place.
/// Must be expanded where an int `n` (the matrix dimension) is visible.
/// Arguments are parenthesized so expressions such as `k + 1` expand
/// correctly; i and j are re-evaluated on each iteration, so pass
/// side-effect-free expressions. The do/while(0) wrapper makes the macro
/// behave as a single statement (safe inside an unbraced if/else).
#define SWAP_COLS_F32(A,i,j)                   \
  do {                                         \
    for (int w_ = 0; w_ < n; w_++)             \
    {                                          \
      float32_t tmp_ = (A)[w_*n + (i)];        \
      (A)[w_*n + (i)] = (A)[w_*n + (j)];       \
      (A)[w_*n + (j)] = tmp_;                  \
    }                                          \
  } while (0)
344 
345 /**
346   @ingroup groupMatrix
347  */
348 
349 /**
350   @addtogroup MatrixChol
351   @{
352  */
353 
354 /**
355    * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
356    * @param[in]  pSrc   points to the instance of the input floating-point matrix structure.
357    * @param[out] pl   points to the instance of the output floating-point triangular matrix structure.
358    * @param[out] pd   points to the instance of the output floating-point diagonal matrix structure.
359    * @param[out] pp   points to the instance of the output floating-point permutation vector.
360    * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
361    * @return        execution status
362                    - \ref ARM_MATH_SUCCESS       : Operation successful
363                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
364                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
365    * @par
366    *  Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
367    */
arm_mat_ldlt_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pl,arm_matrix_instance_f32 * pd,uint16_t * pp)368 arm_status arm_mat_ldlt_f32(
369   const arm_matrix_instance_f32 * pSrc,
370   arm_matrix_instance_f32 * pl,
371   arm_matrix_instance_f32 * pd,
372   uint16_t * pp)
373 {
374 
375   arm_status status;                             /* status of matrix inverse */
376 
377 
378 #ifdef ARM_MATH_MATRIX_CHECK
379 
380   /* Check for matrix mismatch condition */
381   if ((pSrc->numRows != pSrc->numCols) ||
382       (pl->numRows != pl->numCols) ||
383       (pd->numRows != pd->numCols) ||
384       (pl->numRows != pd->numRows)   )
385   {
386     /* Set status as ARM_MATH_SIZE_MISMATCH */
387     status = ARM_MATH_SIZE_MISMATCH;
388   }
389   else
390 
391 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
392 
393   {
394 
395     const int n=pSrc->numRows;
396     int fullRank = 1, diag,k;
397     float32_t *pA;
398 
399     memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
400     pA = pl->pData;
401 
402     for(int k=0;k < n; k++)
403     {
404       pp[k] = k;
405     }
406 
407 
408     for(k=0;k < n; k++)
409     {
410         /* Find pivot */
411         float32_t m=F32_MIN,a;
412         int j=k;
413 
414 
415         for(int r=k;r<n;r++)
416         {
417            if (pA[r*n+r] > m)
418            {
419              m = pA[r*n+r];
420              j = r;
421            }
422         }
423 
424         if(j != k)
425         {
426           SWAP_ROWS_F32(pA,k,j);
427           SWAP_COLS_F32(pA,k,j);
428         }
429 
430 
431         pp[k] = j;
432 
433         a = pA[k*n+k];
434 
435         if (fabs(a) < 1.0e-8)
436         {
437 
438             fullRank = 0;
439             break;
440         }
441 
442         for(int w=k+1;w<n;w++)
443         {
444           for(int x=k+1;x<n;x++)
445           {
446              pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * pA[x*n+k] / a;
447           }
448         }
449 
450         for(int w=k+1;w<n;w++)
451         {
452                pA[w*n+k] = pA[w*n+k] / a;
453         }
454 
455 
456 
457     }
458 
459 
460 
461     diag=k;
462     if (!fullRank)
463     {
464       diag--;
465       for(int row=0; row < n;row++)
466       {
467         for(int col=k; col < n;col++)
468         {
469            pl->pData[row*n+col]=0.0;
470         }
471       }
472     }
473 
474     for(int row=0; row < n;row++)
475     {
476        for(int col=row+1; col < n;col++)
477        {
478          pl->pData[row*n+col] = 0.0;
479        }
480     }
481 
482     for(int d=0; d < diag;d++)
483     {
484       pd->pData[d*n+d] = pl->pData[d*n+d];
485       pl->pData[d*n+d] = 1.0;
486     }
487 
488     status = ARM_MATH_SUCCESS;
489 
490   }
491 
492 
493   /* Return to application */
494   return (status);
495 }
496 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
497 
498 /**
499   @} end of MatrixChol group
500  */
501