• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_svm_rbf_predict_f16.c
 * Description:  SVM Radial Basis Function Classifier
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
28 
29 #include "dsp/svm_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 #include <limits.h>
34 #include <math.h>
35 
36 
/**
 * @addtogroup rbfsvm
 * @{
 */


/**
 * @brief SVM rbf prediction
 * @param[in]    S         Pointer to an instance of the rbf SVM structure.
 * @param[in]    in        Pointer to the input vector
 * @param[out]   pResult   Decision value (predicted class)
 * @return none.
 *
 */
52 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
53 
54 #include "arm_helium_utils.h"
55 #include "arm_vec_math_f16.h"
56 
/*
 * Helium/MVE implementation.  Evaluates the RBF decision function
 *
 *     sum = intercept + SUM_i dualCoef[i] * exp(-gamma * ||in - sv_i||^2)
 *
 * by processing the support-vector matrix like a matrix x vector product,
 * 4 rows (support vectors) at a time, then 2, then 1, with the squared
 * Euclidean distance accumulated per row via vsubq/vfmaq.
 */
void arm_svm_rbf_predict_f16(
    const arm_svm_rbf_instance_f16 *S,
    const float16_t * in,
    int32_t * pResult)
{
        /* inlined Matrix x Vector function interleaved with dot prod */
    uint32_t        numRows = S->nbOfSupportVectors;
    uint32_t        numCols = S->vectorDimension;
    const float16_t *pSupport = S->supportVectors;
    const float16_t *pSrcA = pSupport;          /* walks the support-vector matrix row block by row block */
    const float16_t *pInA0;
    const float16_t *pInA1;
    uint32_t         row;
    uint32_t         blkCnt;     /* loop counters */
    const float16_t *pDualCoef = S->dualCoefficients;
    _Float16       sum = S->intercept;          /* scalar decision value, seeded with the bias */
    f16x8_t         vSum = vdupq_n_f16(0);      /* vector accumulator of dualCoef[i]*exp(...) terms */

    row = numRows;

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4) {
        const float16_t *pInA2, *pInA3;
        float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
        f16x8_t         vecIn, acc0, acc1, acc2, acc3;
        float16_t const *pSrcVecPtr = in;

        /*
         * Initialize the pointers to 4 consecutive MatrixA rows
         * (4 consecutive support vectors)
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators (one squared-distance accumulator per row)
         */
        acc0 = vdupq_n_f16(0.0f);
        acc1 = vdupq_n_f16(0.0f);
        acc2 = vdupq_n_f16(0.0f);
        acc3 = vdupq_n_f16(0.0f);

        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;
        pSrcA2Vec = pInA2;
        pSrcA3Vec = pInA3;

        /* 8 f16 lanes per MVE vector: main loop handles numCols/8 full vectors */
        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            f16x8_t         vecA;
            f16x8_t         vecDif;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            /* acc += (in - sv)^2, lane-wise, for each of the 4 rows */
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc2 = vfmaq(acc2, vecDif, vecDif);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc3 = vfmaq(acc3, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail: remaining numCols % 8 elements, handled with a zeroing
         * predicate so out-of-range lanes contribute nothing
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t    p0 = vctp16q(blkCnt);
            f16x8_t         vecA;
            f16x8_t         vecDif;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA2Vec, p0);;
            vecDif = vsubq(vecIn, vecA);
            acc2 = vfmaq(acc2, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA3Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc3 = vfmaq(acc3, vecDif, vecDif);
        }
        /*
         * Sum the partial parts: pack the 4 row distances into lanes 0..3
         * so exp() and the dual-coefficient multiply run vectorized
         */

        //sum += *pDualCoef++ * expf(-S->gamma * vecReduceF16Mve(acc0));
        f16x8_t         vtmp = vuninitializedq_f16();
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc0), vtmp, 0);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc1), vtmp, 1);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc2), vtmp, 2);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc3), vtmp, 3);

        /* vSum += dualCoef[0..3] * exp(-gamma * dist[0..3]), predicated to 4 lanes */
        vSum =
            vfmaq_m_f16(vSum, vld1q(pDualCoef),
                      vexpq_f16(vmulq_n_f16(vtmp, -S->gamma)),vctp16q(4));
        pDualCoef += 4;
        pSrcA += numCols * 4;
        /*
         * Decrement the row loop counter
         */
        row -= 4;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2) {
        float16_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
        f16x8_t         vecIn, acc0, acc1;
        float16_t const *pSrcVecPtr = in;

        /*
         * Initialize the pointers to 2 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);
        acc1 = vdupq_n_f16(0.0f);
        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            f16x8_t         vecA;
            f16x8_t         vecDif;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);;
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t    p0 = vctp16q(blkCnt);
            f16x8_t         vecA, vecDif;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
        }
        /*
         * Sum the partial parts (lanes 0..1 hold the two row distances)
         */
        f16x8_t         vtmp = vuninitializedq_f16();
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc0), vtmp, 0);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc1), vtmp, 1);

        vSum =
            vfmaq_m_f16(vSum, vld1q(pDualCoef),
                        vexpq_f16(vmulq_n_f16(vtmp, -S->gamma)), vctp16q(2));
        pDualCoef += 2;

        pSrcA += numCols * 2;
        row -= 2;
    }

    /* last remaining row, if any */
    if (row >= 1) {
        f16x8_t         vecIn, acc0;
        float16_t const *pSrcA0Vec, *pInVec;
        float16_t const *pSrcVecPtr = in;
        /*
         * Initialize the pointers to last MatrixA row
         */
        pInA0 = pSrcA;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);

        pSrcA0Vec = pInA0;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            f16x8_t         vecA, vecDif;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t    p0 = vctp16q(blkCnt);
            f16x8_t         vecA, vecDif;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
        }
        /*
         * Sum the partial parts
         */
        f16x8_t         vtmp = vuninitializedq_f16();
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc0), vtmp, 0);

        vSum =
            vfmaq_m_f16(vSum, vld1q(pDualCoef),
                        vexpq_f16(vmulq_n_f16(vtmp, -S->gamma)), vctp16q(1));

    }


    /* reduce the per-lane kernel terms and add to the bias */
    sum += vecAddAcrossF16Mve(vSum);
    /* STEP() selects the class index from the decision value's sign
       (project macro defined elsewhere — see svm support header) */
    *pResult = S->classes[STEP(sum)];
}
320 
321 #else
arm_svm_rbf_predict_f16(const arm_svm_rbf_instance_f16 * S,const float16_t * in,int32_t * pResult)322 void arm_svm_rbf_predict_f16(
323     const arm_svm_rbf_instance_f16 *S,
324     const float16_t * in,
325     int32_t * pResult)
326 {
327     _Float16 sum=S->intercept;
328     _Float16 dot=00.f16;
329     uint32_t i,j;
330     const float16_t *pSupport = S->supportVectors;
331 
332     for(i=0; i < S->nbOfSupportVectors; i++)
333     {
334         dot=0.0f16;
335         for(j=0; j < S->vectorDimension; j++)
336         {
337             dot = dot + SQ((_Float16)in[j] - (_Float16) *pSupport);
338             pSupport++;
339         }
340         sum += (_Float16)S->dualCoefficients[i] * (_Float16)expf(-(_Float16)S->gamma * dot);
341     }
342     *pResult=S->classes[STEP(sum)];
343 }
344 
345 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
346 
/**
 * @} end of rbfsvm group
 */
350 
351 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
352 
353