• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_solve_lower_triangular_f32.c
4  * Description:  Solve linear system LT X = A with LT lower triangular matrix
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 
36 /**
37   @addtogroup MatrixInv
38   @{
39  */
40 
41 
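   /**
   * @brief Solve LT . X = A where LT is a lower triangular matrix
   * @param[in]  lt  The lower triangular matrix LT
   * @param[in]  a   The right-hand side matrix A
   * @param[out] dst The solution X of LT . X = A
   * @return ARM_MATH_SUCCESS when the system has been solved,
   *         ARM_MATH_SINGULAR if the system can't be solved (a diagonal
   *         element of LT is zero), or ARM_MATH_SIZE_MISMATCH if
   *         ARM_MATH_MATRIX_CHECK is defined and the matrix dimensions are
   *         inconsistent.
   */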
49 
50 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
51 
52 #include "arm_helium_utils.h"
53 
arm_mat_solve_lower_triangular_f32(const arm_matrix_instance_f32 * lt,const arm_matrix_instance_f32 * a,arm_matrix_instance_f32 * dst)54   arm_status arm_mat_solve_lower_triangular_f32(
55   const arm_matrix_instance_f32 * lt,
56   const arm_matrix_instance_f32 * a,
57   arm_matrix_instance_f32 * dst)
58   {
59   arm_status status;                             /* status of matrix inverse */
60 
61 
62 #ifdef ARM_MATH_MATRIX_CHECK
63 
64   /* Check for matrix mismatch condition */
65   if ((lt->numRows != lt->numCols) ||
66       (a->numRows != a->numCols) ||
67       (lt->numRows != a->numRows)   )
68   {
69     /* Set status as ARM_MATH_SIZE_MISMATCH */
70     status = ARM_MATH_SIZE_MISMATCH;
71   }
72   else
73 
74 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
75 
76   {
77     /* a1 b1 c1   x1 = a1
78           b2 c2   x2   a2
79              c3   x3   a3
80 
81     x3 = a3 / c3
82     x2 = (a2 - c2 x3) / b2
83 
84     */
85     int i,j,k,n;
86 
87     n = dst->numRows;
88 
89     float32_t *pX = dst->pData;
90     float32_t *pLT = lt->pData;
91     float32_t *pA = a->pData;
92 
93     float32_t *lt_row;
94     float32_t *a_col;
95 
96     float32_t invLT;
97 
98     f32x4_t vecA;
99     f32x4_t vecX;
100 
101     for(i=0; i < n ; i++)
102     {
103 
104       for(j=0; j+3 < n; j += 4)
105       {
106             vecA = vld1q_f32(&pA[i * n + j]);
107 
108             for(k=0; k < i; k++)
109             {
110                 vecX = vld1q_f32(&pX[n*k+j]);
111                 vecA = vfmsq(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
112             }
113 
114             if (pLT[n*i + i]==0.0f)
115             {
116               return(ARM_MATH_SINGULAR);
117             }
118 
119             invLT = 1.0f / pLT[n*i + i];
120             vecA = vmulq(vecA,vdupq_n_f32(invLT));
121             vst1q(&pX[i*n+j],vecA);
122 
123        }
124 
125        for(; j < n; j ++)
126        {
127             a_col = &pA[j];
128             lt_row = &pLT[n*i];
129 
130             float32_t tmp=a_col[i * n];
131 
132             for(k=0; k < i; k++)
133             {
134                 tmp -= lt_row[k] * pX[n*k+j];
135             }
136 
137             if (lt_row[i]==0.0f)
138             {
139               return(ARM_MATH_SINGULAR);
140             }
141             tmp = tmp / lt_row[i];
142             pX[i*n+j] = tmp;
143         }
144 
145     }
146     status = ARM_MATH_SUCCESS;
147 
148   }
149 
150   /* Return to application */
151   return (status);
152 }
153 #else
154 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
arm_mat_solve_lower_triangular_f32(const arm_matrix_instance_f32 * lt,const arm_matrix_instance_f32 * a,arm_matrix_instance_f32 * dst)155   arm_status arm_mat_solve_lower_triangular_f32(
156   const arm_matrix_instance_f32 * lt,
157   const arm_matrix_instance_f32 * a,
158   arm_matrix_instance_f32 * dst)
159   {
160   arm_status status;                             /* status of matrix inverse */
161 
162 
163 #ifdef ARM_MATH_MATRIX_CHECK
164 
165   /* Check for matrix mismatch condition */
166   if ((lt->numRows != lt->numCols) ||
167       (a->numRows != a->numCols) ||
168       (lt->numRows != a->numRows)   )
169   {
170     /* Set status as ARM_MATH_SIZE_MISMATCH */
171     status = ARM_MATH_SIZE_MISMATCH;
172   }
173   else
174 
175 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
176 
177   {
178     /* a1 b1 c1   x1 = a1
179           b2 c2   x2   a2
180              c3   x3   a3
181 
182     x3 = a3 / c3
183     x2 = (a2 - c2 x3) / b2
184 
185     */
186     int i,j,k,n;
187 
188     n = dst->numRows;
189 
190     float32_t *pX = dst->pData;
191     float32_t *pLT = lt->pData;
192     float32_t *pA = a->pData;
193 
194     float32_t *lt_row;
195     float32_t *a_col;
196 
197     float32_t invLT;
198 
199     f32x4_t vecA;
200     f32x4_t vecX;
201 
202     for(i=0; i < n ; i++)
203     {
204 
205       for(j=0; j+3 < n; j += 4)
206       {
207             vecA = vld1q_f32(&pA[i * n + j]);
208 
209             for(k=0; k < i; k++)
210             {
211                 vecX = vld1q_f32(&pX[n*k+j]);
212                 vecA = vfmsq_f32(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
213             }
214 
215             if (pLT[n*i + i]==0.0f)
216             {
217               return(ARM_MATH_SINGULAR);
218             }
219 
220             invLT = 1.0f / pLT[n*i + i];
221             vecA = vmulq_f32(vecA,vdupq_n_f32(invLT));
222             vst1q_f32(&pX[i*n+j],vecA);
223 
224        }
225 
226        for(; j < n; j ++)
227        {
228             a_col = &pA[j];
229             lt_row = &pLT[n*i];
230 
231             float32_t tmp=a_col[i * n];
232 
233             for(k=0; k < i; k++)
234             {
235                 tmp -= lt_row[k] * pX[n*k+j];
236             }
237 
238             if (lt_row[i]==0.0f)
239             {
240               return(ARM_MATH_SINGULAR);
241             }
242             tmp = tmp / lt_row[i];
243             pX[i*n+j] = tmp;
244         }
245 
246     }
247     status = ARM_MATH_SUCCESS;
248 
249   }
250 
251   /* Return to application */
252   return (status);
253 }
254 #else
arm_mat_solve_lower_triangular_f32(const arm_matrix_instance_f32 * lt,const arm_matrix_instance_f32 * a,arm_matrix_instance_f32 * dst)255   arm_status arm_mat_solve_lower_triangular_f32(
256   const arm_matrix_instance_f32 * lt,
257   const arm_matrix_instance_f32 * a,
258   arm_matrix_instance_f32 * dst)
259   {
260   arm_status status;                             /* status of matrix inverse */
261 
262 
263 #ifdef ARM_MATH_MATRIX_CHECK
264   /* Check for matrix mismatch condition */
265   if ((lt->numRows != lt->numCols) ||
266       (a->numRows != a->numCols) ||
267       (lt->numRows != a->numRows)   )
268   {
269     /* Set status as ARM_MATH_SIZE_MISMATCH */
270     status = ARM_MATH_SIZE_MISMATCH;
271   }
272   else
273 
274 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
275 
276   {
277     /* a1 b1 c1   x1 = a1
278           b2 c2   x2   a2
279              c3   x3   a3
280 
281     x3 = a3 / c3
282     x2 = (a2 - c2 x3) / b2
283 
284     */
285     int i,j,k,n;
286 
287     n = dst->numRows;
288 
289     float32_t *pX = dst->pData;
290     float32_t *pLT = lt->pData;
291     float32_t *pA = a->pData;
292 
293     float32_t *lt_row;
294     float32_t *a_col;
295 
296     for(j=0; j < n; j ++)
297     {
298        a_col = &pA[j];
299 
300        for(i=0; i < n ; i++)
301        {
302             lt_row = &pLT[n*i];
303 
304             float32_t tmp=a_col[i * n];
305 
306             for(k=0; k < i; k++)
307             {
308                 tmp -= lt_row[k] * pX[n*k+j];
309             }
310 
311             if (lt_row[i]==0.0f)
312             {
313               return(ARM_MATH_SINGULAR);
314             }
315             tmp = tmp / lt_row[i];
316             pX[i*n+j] = tmp;
317        }
318 
319     }
320     status = ARM_MATH_SUCCESS;
321 
322   }
323 
324   /* Return to application */
325   return (status);
326 }
327 #endif /* #if defined(ARM_MATH_NEON) */
328 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
329 
330 /**
331   @} end of MatrixInv group
332  */
333