1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_logsumexp_f32.c
4 * Description: LogSumExp
5 *
6 *
7 * Target Processor: Cortex-M and Cortex-A cores
8 * -------------------------------------------------------------------- */
9 /*
10 * Copyright (C) 2010-2019 ARM Limited or its affiliates. All rights reserved.
11 *
12 * SPDX-License-Identifier: Apache-2.0
13 *
14 * Licensed under the Apache License, Version 2.0 (the License); you may
15 * not use this file except in compliance with the License.
16 * You may obtain a copy of the License at
17 *
18 * www.apache.org/licenses/LICENSE-2.0
19 *
20 * Unless required by applicable law or agreed to in writing, software
21 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
22 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 * See the License for the specific language governing permissions and
24 * limitations under the License.
25 */
26
27 #include "arm_math.h"
28 #include <limits.h>
29 #include <math.h>
30
31
32 /**
33 * @addtogroup groupStats
34 * @{
35 */
36
37
/**
 * @brief Computation of the LogSumExp
 *
 * In probabilistic computations, probability values can span a very wide
 * dynamic range, for example when they come from Gaussian functions.
 * To avoid underflow and overflow, the values are represented by their logarithms.
 * In this representation, multiplying the original values is easy: their logs are added.
 * Adding the original values, however, requires special handling, and that is the
 * purpose of the LogSumExp function.
 *
 * If the values are x1...xn, the function computes:
 *
 * ln(exp(x1) + ... + exp(xn)), and the computation is done in such a way that
 * overflow, underflow and rounding issues are minimised.
 *
 * The maximum xm of the values is extracted and the function computes the
 * mathematically equivalent expression:
 * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
 * in which the largest exponential term is exp(0) = 1.
 *
 * @param[in]  in          Pointer to an array of input values.
 * @param[in]  blockSize   Number of samples in the input array.
 * @return     LogSumExp of the input values
 *
 */
61
62 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
63
64 #include "arm_helium_utils.h"
65 #include "arm_vec_math.h"
66
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)67 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
68 {
69 float32_t maxVal;
70 const float32_t *pIn;
71 int32_t blkCnt;
72 float32_t accum=0.0f;
73 float32_t tmp;
74
75
76 arm_max_no_idx_f32((float32_t *) in, blockSize, &maxVal);
77
78
79 blkCnt = blockSize;
80 pIn = in;
81
82
83 f32x4_t vSum = vdupq_n_f32(0.0f);
84 blkCnt = blockSize >> 2;
85 while(blkCnt > 0)
86 {
87 f32x4_t vecIn = vld1q(pIn);
88 f32x4_t vecExp;
89
90 vecExp = vexpq_f32(vsubq_n_f32(vecIn, maxVal));
91
92 vSum = vaddq_f32(vSum, vecExp);
93
94 /*
95 * Decrement the blockSize loop counter
96 * Advance vector source and destination pointers
97 */
98 pIn += 4;
99 blkCnt --;
100 }
101
102 /* sum + log */
103 accum = vecAddAcrossF32Mve(vSum);
104
105 blkCnt = blockSize & 0x3;
106 while(blkCnt > 0)
107 {
108 tmp = *pIn++;
109 accum += expf(tmp - maxVal);
110 blkCnt--;
111
112 }
113
114 accum = maxVal + log(accum);
115
116 return (accum);
117 }
118
119 #else
120 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
121
122 #include "NEMath.h"
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)123 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
124 {
125 float32_t maxVal;
126 float32_t tmp;
127 float32x4_t tmpV, tmpVb;
128 float32x4_t maxValV;
129 uint32x4_t idxV;
130 float32x4_t accumV;
131 float32x2_t accumV2;
132
133 const float32_t *pIn;
134 uint32_t blkCnt;
135 float32_t accum;
136
137 pIn = in;
138
139 blkCnt = blockSize;
140
141 if (blockSize <= 3)
142 {
143 maxVal = *pIn++;
144 blkCnt--;
145
146 while(blkCnt > 0)
147 {
148 tmp = *pIn++;
149
150 if (tmp > maxVal)
151 {
152 maxVal = tmp;
153 }
154 blkCnt--;
155 }
156 }
157 else
158 {
159 maxValV = vld1q_f32(pIn);
160 pIn += 4;
161 blkCnt = (blockSize - 4) >> 2;
162
163 while(blkCnt > 0)
164 {
165 tmpVb = vld1q_f32(pIn);
166 pIn += 4;
167
168 idxV = vcgtq_f32(tmpVb, maxValV);
169 maxValV = vbslq_f32(idxV, tmpVb, maxValV );
170
171 blkCnt--;
172 }
173
174 accumV2 = vpmax_f32(vget_low_f32(maxValV),vget_high_f32(maxValV));
175 accumV2 = vpmax_f32(accumV2,accumV2);
176 maxVal = vget_lane_f32(accumV2, 0) ;
177
178 blkCnt = (blockSize - 4) & 3;
179
180 while(blkCnt > 0)
181 {
182 tmp = *pIn++;
183
184 if (tmp > maxVal)
185 {
186 maxVal = tmp;
187 }
188 blkCnt--;
189 }
190
191 }
192
193
194
195 maxValV = vdupq_n_f32(maxVal);
196 pIn = in;
197 accum = 0;
198 accumV = vdupq_n_f32(0.0f);
199
200 blkCnt = blockSize >> 2;
201
202 while(blkCnt > 0)
203 {
204 tmpV = vld1q_f32(pIn);
205 pIn += 4;
206 tmpV = vsubq_f32(tmpV, maxValV);
207 tmpV = vexpq_f32(tmpV);
208 accumV = vaddq_f32(accumV, tmpV);
209
210 blkCnt--;
211
212 }
213 accumV2 = vpadd_f32(vget_low_f32(accumV),vget_high_f32(accumV));
214 accum = vget_lane_f32(accumV2, 0) + vget_lane_f32(accumV2, 1);
215
216 blkCnt = blockSize & 0x3;
217 while(blkCnt > 0)
218 {
219 tmp = *pIn++;
220 accum += expf(tmp - maxVal);
221 blkCnt--;
222
223 }
224
225 accum = maxVal + logf(accum);
226
227 return(accum);
228 }
229 #else
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)230 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
231 {
232 float32_t maxVal;
233 float32_t tmp;
234 const float32_t *pIn;
235 uint32_t blkCnt;
236 float32_t accum;
237
238 pIn = in;
239 blkCnt = blockSize;
240
241 maxVal = *pIn++;
242 blkCnt--;
243
244 while(blkCnt > 0)
245 {
246 tmp = *pIn++;
247
248 if (tmp > maxVal)
249 {
250 maxVal = tmp;
251 }
252 blkCnt--;
253
254 }
255
256 blkCnt = blockSize;
257 pIn = in;
258 accum = 0;
259 while(blkCnt > 0)
260 {
261 tmp = *pIn++;
262 accum += expf(tmp - maxVal);
263 blkCnt--;
264
265 }
266 accum = maxVal + logf(accum);
267
268 return(accum);
269 }
270 #endif
271 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
272
273 /**
274 * @} end of groupStats group
275 */
276