• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_logsumexp_f32.c
4  * Description:  LogSumExp
5  *
6  *
7  * Target Processor: Cortex-M and Cortex-A cores
8  * -------------------------------------------------------------------- */
9 /*
10  * Copyright (C) 2010-2019 ARM Limited or its affiliates. All rights reserved.
11  *
12  * SPDX-License-Identifier: Apache-2.0
13  *
14  * Licensed under the Apache License, Version 2.0 (the License); you may
15  * not use this file except in compliance with the License.
16  * You may obtain a copy of the License at
17  *
18  * www.apache.org/licenses/LICENSE-2.0
19  *
20  * Unless required by applicable law or agreed to in writing, software
21  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
22  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23  * See the License for the specific language governing permissions and
24  * limitations under the License.
25  */
26 
27 #include "arm_math.h"
28 #include <limits.h>
29 #include <math.h>
30 
31 
32 /**
33  * @addtogroup groupStats
34  * @{
35  */
36 
37 
38 /**
39  * @brief Computation of the LogSumExp
40  *
 * In probabilistic computations, the dynamic range of the probability values can be
 * very wide because they come from Gaussian functions.
 * To avoid underflow and overflow issues, the values are represented by their logarithms.
 * In this representation, multiplying the original values is easy: their logs are added.
 * Adding the original values, however, requires special handling, and that is the
 * purpose of the LogSumExp function.
47  *
48  * If the values are x1...xn, the function is computing:
49  *
50  * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that
51  * rounding issues are minimised.
52  *
53  * The max xm of the values is extracted and the function is computing:
54  * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
55  *
56  * @param[in]  *in         Pointer to an array of input values.
57  * @param[in]  blockSize   Number of samples in the input array.
58  * @return LogSumExp
59  *
60  */
61 
62 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
63 
64 #include "arm_helium_utils.h"
65 #include "arm_vec_math.h"
66 
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)67 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
68 {
69     float32_t       maxVal;
70     const float32_t *pIn;
71     int32_t         blkCnt;
72     float32_t       accum=0.0f;
73     float32_t       tmp;
74 
75 
76     arm_max_no_idx_f32((float32_t *) in, blockSize, &maxVal);
77 
78 
79     blkCnt = blockSize;
80     pIn = in;
81 
82 
83     f32x4_t         vSum = vdupq_n_f32(0.0f);
84     blkCnt = blockSize >> 2;
85     while(blkCnt > 0)
86     {
87         f32x4_t         vecIn = vld1q(pIn);
88         f32x4_t         vecExp;
89 
90         vecExp = vexpq_f32(vsubq_n_f32(vecIn, maxVal));
91 
92         vSum = vaddq_f32(vSum, vecExp);
93 
94         /*
95          * Decrement the blockSize loop counter
96          * Advance vector source and destination pointers
97          */
98         pIn += 4;
99         blkCnt --;
100     }
101 
102     /* sum + log */
103     accum = vecAddAcrossF32Mve(vSum);
104 
105     blkCnt = blockSize & 0x3;
106     while(blkCnt > 0)
107     {
108        tmp = *pIn++;
109        accum += expf(tmp - maxVal);
110        blkCnt--;
111 
112     }
113 
114     accum = maxVal + log(accum);
115 
116     return (accum);
117 }
118 
119 #else
120 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
121 
122 #include "NEMath.h"
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)123 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
124 {
125     float32_t maxVal;
126     float32_t tmp;
127     float32x4_t tmpV, tmpVb;
128     float32x4_t maxValV;
129     uint32x4_t idxV;
130     float32x4_t accumV;
131     float32x2_t accumV2;
132 
133     const float32_t *pIn;
134     uint32_t blkCnt;
135     float32_t accum;
136 
137     pIn = in;
138 
139     blkCnt = blockSize;
140 
141     if (blockSize <= 3)
142     {
143       maxVal = *pIn++;
144       blkCnt--;
145 
146       while(blkCnt > 0)
147       {
148          tmp = *pIn++;
149 
150          if (tmp > maxVal)
151          {
152             maxVal = tmp;
153          }
154          blkCnt--;
155       }
156     }
157     else
158     {
159       maxValV = vld1q_f32(pIn);
160       pIn += 4;
161       blkCnt = (blockSize - 4) >> 2;
162 
163       while(blkCnt > 0)
164       {
165          tmpVb = vld1q_f32(pIn);
166          pIn += 4;
167 
168          idxV = vcgtq_f32(tmpVb, maxValV);
169          maxValV = vbslq_f32(idxV, tmpVb, maxValV );
170 
171          blkCnt--;
172       }
173 
174       accumV2 = vpmax_f32(vget_low_f32(maxValV),vget_high_f32(maxValV));
175       accumV2 = vpmax_f32(accumV2,accumV2);
176       maxVal = vget_lane_f32(accumV2, 0) ;
177 
178       blkCnt = (blockSize - 4) & 3;
179 
180       while(blkCnt > 0)
181       {
182          tmp = *pIn++;
183 
184          if (tmp > maxVal)
185          {
186             maxVal = tmp;
187          }
188          blkCnt--;
189       }
190 
191     }
192 
193 
194 
195     maxValV = vdupq_n_f32(maxVal);
196     pIn = in;
197     accum = 0;
198     accumV = vdupq_n_f32(0.0f);
199 
200     blkCnt = blockSize >> 2;
201 
202     while(blkCnt > 0)
203     {
204        tmpV = vld1q_f32(pIn);
205        pIn += 4;
206        tmpV = vsubq_f32(tmpV, maxValV);
207        tmpV = vexpq_f32(tmpV);
208        accumV = vaddq_f32(accumV, tmpV);
209 
210        blkCnt--;
211 
212     }
213     accumV2 = vpadd_f32(vget_low_f32(accumV),vget_high_f32(accumV));
214     accum = vget_lane_f32(accumV2, 0) + vget_lane_f32(accumV2, 1);
215 
216     blkCnt = blockSize & 0x3;
217     while(blkCnt > 0)
218     {
219        tmp = *pIn++;
220        accum += expf(tmp - maxVal);
221        blkCnt--;
222 
223     }
224 
225     accum = maxVal + logf(accum);
226 
227     return(accum);
228 }
229 #else
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)230 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
231 {
232     float32_t maxVal;
233     float32_t tmp;
234     const float32_t *pIn;
235     uint32_t blkCnt;
236     float32_t accum;
237 
238     pIn = in;
239     blkCnt = blockSize;
240 
241     maxVal = *pIn++;
242     blkCnt--;
243 
244     while(blkCnt > 0)
245     {
246        tmp = *pIn++;
247 
248        if (tmp > maxVal)
249        {
250           maxVal = tmp;
251        }
252        blkCnt--;
253 
254     }
255 
256     blkCnt = blockSize;
257     pIn = in;
258     accum = 0;
259     while(blkCnt > 0)
260     {
261        tmp = *pIn++;
262        accum += expf(tmp - maxVal);
263        blkCnt--;
264 
265     }
266     accum = maxVal + logf(accum);
267 
268     return(accum);
269 }
270 #endif
271 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
272 
273 /**
274  * @} end of groupStats group
275  */
276