1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_svm_rbf_predict_f16.c
4 * Description: SVM Radial Basis Function Classifier
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/svm_functions_f16.h"
30
31 #if defined(ARM_FLOAT16_SUPPORTED)
32
33 #include <limits.h>
34 #include <math.h>
35
36
37 /**
38 * @addtogroup rbfsvm
39 * @{
40 */
41
42
43 /**
44 * @brief SVM rbf prediction
45 * @param[in] S Pointer to an instance of the rbf SVM structure.
46 * @param[in] in Pointer to input vector
47 * @param[out] pResult decision value
48 * @return none.
49 *
50 */
51
52 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
53
54 #include "arm_helium_utils.h"
55 #include "arm_vec_math_f16.h"
56
/*
 * Helium (MVE) half-precision implementation of the RBF SVM prediction.
 *
 * For every support vector the squared Euclidean distance to `in` is
 * accumulated, scaled by -gamma, passed through exp() and weighted by the
 * matching dual coefficient. These terms are added to the intercept and the
 * class label is read from S->classes at the index produced by the STEP
 * macro (defined outside this file).
 *
 * Support vectors are read as consecutive rows of S->vectorDimension
 * elements. Rows are processed four at a time, then a group of two, then a
 * single row, so that each group fills 4, 2 or 1 lanes of the kernel vector.
 *
 * S        Pointer to an instance of the rbf SVM structure.
 * in       Pointer to the input vector (S->vectorDimension elements).
 * pResult  Receives the selected class label.
 */
void arm_svm_rbf_predict_f16(
    const arm_svm_rbf_instance_f16 *S,
    const float16_t * in,
    int32_t * pResult)
{
    uint32_t         numCols   = S->vectorDimension;
    uint32_t         row       = S->nbOfSupportVectors;   /* rows still to process */
    const float16_t *pSrcA     = S->supportVectors;       /* first unprocessed row */
    const float16_t *pDualCoef = S->dualCoefficients;     /* coefficient of that row */
    const float16_t  negGamma  = -S->gamma;               /* loop invariant scale factor */
    uint32_t         blkCnt;                              /* loop counter */
    _Float16         sum       = S->intercept;
    f16x8_t          vSum      = vdupq_n_f16(0);

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4) {
        /*
         * Pointers to 4 consecutive support-vector rows and to the input
         */
        const float16_t *pSrcA0Vec = pSrcA;
        const float16_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float16_t *pSrcA2Vec = pSrcA1Vec + numCols;
        const float16_t *pSrcA3Vec = pSrcA2Vec + numCols;
        const float16_t *pInVec    = in;
        f16x8_t          vecIn, vecA, vecDif;
        /*
         * reset accumulators
         */
        f16x8_t          acc0 = vdupq_n_f16(0.0f);
        f16x8_t          acc1 = vdupq_n_f16(0.0f);
        f16x8_t          acc2 = vdupq_n_f16(0.0f);
        f16x8_t          acc3 = vdupq_n_f16(0.0f);
        mve_pred16_t     pRows;
        f16x8_t          vDist;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            vecIn = vld1q(pInVec);
            pInVec += 8;

            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);

            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc2 = vfmaq(acc2, vecDif, vecDif);

            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc3 = vfmaq(acc3, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail: remaining (numCols % 8) columns, loaded with a zeroing
         * predicate so the extra lanes contribute (0 - 0)^2 = 0
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t p0 = vctp16q(blkCnt);

            vecIn = vldrhq_z_f16(pInVec, p0);

            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);

            vecA = vldrhq_z_f16(pSrcA2Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc2 = vfmaq(acc2, vecDif, vecDif);

            vecA = vldrhq_z_f16(pSrcA3Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc3 = vfmaq(acc3, vecDif, vecDif);
        }
        /*
         * Gather the 4 squared distances in lanes 0..3. The vector starts
         * from zero so the exponential is never evaluated on indeterminate
         * lanes.
         */
        vDist = vdupq_n_f16(0);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc0), vDist, 0);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc1), vDist, 1);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc2), vDist, 2);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc3), vDist, 3);

        /*
         * Only 4 dual coefficients belong to this group of rows: use a
         * predicated load so nothing is read past the end of the
         * coefficient array on the last group.
         */
        pRows = vctp16q(4);
        vSum = vfmaq_m_f16(vSum, vldrhq_z_f16(pDualCoef, pRows),
                           vexpq_f16(vmulq_n_f16(vDist, negGamma)), pRows);

        pDualCoef += 4;
        pSrcA += numCols * 4;
        /*
         * Decrement the row loop counter
         */
        row -= 4;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2) {
        /*
         * Pointers to 2 consecutive support-vector rows and to the input
         */
        const float16_t *pSrcA0Vec = pSrcA;
        const float16_t *pSrcA1Vec = pSrcA0Vec + numCols;
        const float16_t *pInVec    = in;
        f16x8_t          vecIn, vecA, vecDif;
        /*
         * reset accumulators
         */
        f16x8_t          acc0 = vdupq_n_f16(0.0f);
        f16x8_t          acc1 = vdupq_n_f16(0.0f);
        mve_pred16_t     pRows;
        f16x8_t          vDist;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            vecIn = vld1q(pInVec);
            pInVec += 8;

            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail: remaining (numCols % 8) columns under a zeroing predicate
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t p0 = vctp16q(blkCnt);

            vecIn = vldrhq_z_f16(pInVec, p0);

            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
        }
        /*
         * Gather the 2 squared distances in lanes 0..1
         */
        vDist = vdupq_n_f16(0);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc0), vDist, 0);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc1), vDist, 1);

        /*
         * Predicated load: only 2 coefficients are valid for this group
         */
        pRows = vctp16q(2);
        vSum = vfmaq_m_f16(vSum, vldrhq_z_f16(pDualCoef, pRows),
                           vexpq_f16(vmulq_n_f16(vDist, negGamma)), pRows);

        pDualCoef += 2;
        pSrcA += numCols * 2;
        row -= 2;
    }

    /*
     * last remaining row
     */
    if (row >= 1) {
        const float16_t *pSrcA0Vec = pSrcA;
        const float16_t *pInVec    = in;
        f16x8_t          vecIn, vecA, vecDif;
        /*
         * reset accumulator
         */
        f16x8_t          acc0 = vdupq_n_f16(0.0f);
        mve_pred16_t     pRows;
        f16x8_t          vDist;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            vecIn = vld1q(pInVec);
            pInVec += 8;

            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail: remaining (numCols % 8) columns under a zeroing predicate
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t p0 = vctp16q(blkCnt);

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
        }
        /*
         * Single squared distance in lane 0
         */
        vDist = vdupq_n_f16(0);
        vDist = vsetq_lane(vecAddAcrossF16Mve(acc0), vDist, 0);

        /*
         * Predicated load: only 1 coefficient is left
         */
        pRows = vctp16q(1);
        vSum = vfmaq_m_f16(vSum, vldrhq_z_f16(pDualCoef, pRows),
                           vexpq_f16(vmulq_n_f16(vDist, negGamma)), pRows);
    }

    /*
     * Reduce the per-lane kernel contributions and pick the class
     */
    sum += vecAddAcrossF16Mve(vSum);
    *pResult = S->classes[STEP(sum)];
}
320
321 #else
/*
 * Scalar half-precision implementation of the RBF SVM prediction.
 *
 * Walks the support vectors in storage order. For each one it accumulates
 * the squared difference to `in` over S->vectorDimension elements, then adds
 * dualCoefficient * expf(-gamma * squaredDistance) to the running decision
 * value, which starts at S->intercept. The class label is read from
 * S->classes at the index produced by the STEP macro (defined outside this
 * file).
 *
 * S        Pointer to an instance of the rbf SVM structure.
 * in       Pointer to the input vector (S->vectorDimension elements).
 * pResult  Receives the selected class label.
 */
void arm_svm_rbf_predict_f16(
    const arm_svm_rbf_instance_f16 *S,
    const float16_t * in,
    int32_t * pResult)
{
    const float16_t *pVec = S->supportVectors;   /* walks all support vectors back to back */
    _Float16 decision = S->intercept;
    uint32_t sv = 0U;

    while (sv < S->nbOfSupportVectors)
    {
        _Float16 sqDist = 0.0f16;
        uint32_t k;

        /* Squared Euclidean distance between the input and the current support vector */
        for (k = 0U; k < S->vectorDimension; k++, pVec++)
        {
            sqDist = sqDist + SQ((_Float16)in[k] - (_Float16) *pVec);
        }

        /* Weighted RBF kernel contribution of this support vector */
        decision += (_Float16)S->dualCoefficients[sv] * (_Float16)expf(-(_Float16)S->gamma * sqDist);
        sv++;
    }

    *pResult = S->classes[STEP(decision)];
}
344
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
346
347 /**
348 * @} end of rbfsvm group
349 */
350
351 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
352
353