• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "jni/jni_multiclass_pa.h"
18 #include "native/multiclass_pa.h"
19 
20 #include <vector>
21 
22 using learningfw::MulticlassPA;
23 using std::vector;
24 using std::pair;
25 
// Rebuilds *pairs as the element-wise zip of two parallel arrays.
//
// Any previous contents of *pairs are discarded. Afterwards, for a positive
// length, (*pairs)[i] == (indices[i], values[i]) for every i in [0, length).
// A zero or negative length leaves *pairs empty.
//
// indices, values: arrays readable for at least `length` elements.
// length:          number of elements to copy.
// pairs:           output vector; dereferenced unconditionally, so it must
//                  not be NULL.
void CreateIndexValuePairs(const int* indices, const float* values,
                           const int length, vector<pair<int, float> >* pairs) {
  pairs->clear();

  if (length <= 0) {
    return;
  }

  // The final size is known up front, so allocate once instead of letting
  // push_back grow the buffer repeatedly.
  pairs->reserve(static_cast<vector<pair<int, float> >::size_type>(length));

  for (int i = 0; i < length; ++i) {
    pairs->push_back(pair<int, float>(indices[i], values[i]));
  }
}
35 
Java_android_bordeaux_learning_MulticlassPA_initNativeClassifier(JNIEnv *,jobject,jint num_classes,jint num_dims,jfloat aggressiveness)36 jlong Java_android_bordeaux_learning_MulticlassPA_initNativeClassifier(JNIEnv* /* env */,
37                                                        jobject /* thiz */,
38                                                        jint num_classes,
39                                                        jint num_dims,
40                                                        jfloat aggressiveness) {
41   MulticlassPA* classifier = new MulticlassPA(num_classes,
42                                               num_dims,
43                                               aggressiveness);
44   return ((jlong) classifier);
45 }
46 
47 
Java_android_bordeaux_learning_MulticlassPA_deleteNativeClassifier(JNIEnv *,jobject,jlong paPtr)48 jboolean Java_android_bordeaux_learning_MulticlassPA_deleteNativeClassifier(JNIEnv* /* env */,
49                                                              jobject /* thiz */,
50                                                              jlong paPtr) {
51   MulticlassPA* classifier = (MulticlassPA*) paPtr;
52   delete classifier;
53   return JNI_TRUE;
54 }
55 
Java_android_bordeaux_learning_MulticlassPA_nativeSparseTrainOneExample(JNIEnv * env,jobject,jintArray index_array,jfloatArray value_array,jint target,jlong paPtr)56 jboolean Java_android_bordeaux_learning_MulticlassPA_nativeSparseTrainOneExample(JNIEnv* env,
57                                                                   jobject /* thiz */,
58                                                                   jintArray index_array,
59                                                                   jfloatArray value_array,
60                                                                   jint target,
61                                                                   jlong paPtr) {
62   MulticlassPA* classifier = (MulticlassPA*) paPtr;
63 
64   if (classifier && index_array && value_array) {
65 
66     jfloat* values = env->GetFloatArrayElements(value_array, NULL);
67     jint* indices = env->GetIntArrayElements(index_array, NULL);
68     const int value_len = env->GetArrayLength(value_array);
69     const int index_len = env->GetArrayLength(index_array);
70 
71     if (values && indices && value_len == index_len) {
72       vector<pair<int, float> > inputs;
73 
74       CreateIndexValuePairs(indices, values, value_len, &inputs);
75       classifier->SparseTrainOneExample(inputs, target);
76       env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
77       env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
78 
79       return JNI_TRUE;
80     }
81     env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
82     env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
83   }
84 
85   return JNI_FALSE;
86 }
87 
88 
Java_android_bordeaux_learning_MulticlassPA_nativeSparseGetClass(JNIEnv * env,jobject,jintArray index_array,jfloatArray value_array,jlong paPtr)89 jint Java_android_bordeaux_learning_MulticlassPA_nativeSparseGetClass(JNIEnv* env,
90                                                        jobject /* thiz */,
91                                                        jintArray index_array,
92                                                        jfloatArray value_array,
93                                                        jlong paPtr) {
94 
95   MulticlassPA* classifier = (MulticlassPA*) paPtr;
96 
97   if (classifier && index_array && value_array) {
98 
99     jfloat* values = env->GetFloatArrayElements(value_array, NULL);
100     jint* indices = env->GetIntArrayElements(index_array, NULL);
101     const int value_len = env->GetArrayLength(value_array);
102     const int index_len = env->GetArrayLength(index_array);
103 
104     if (values && indices && value_len == index_len) {
105       vector<pair<int, float> > inputs;
106       CreateIndexValuePairs(indices, values, value_len, &inputs);
107       env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
108       env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
109       return classifier->SparseGetClass(inputs);
110     }
111     env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
112     env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
113   }
114 
115   return -1;
116 }
117