1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <jni.h>
18 #include <fstream>
19 #include "common/ms_log.h"
20 #include "include/lite_session.h"
21 #include "include/errorcode.h"
Java_com_mindspore_lite_LiteSession_createSessionWithModel(JNIEnv * env,jobject thiz,jobject model_buffer,jlong ms_config_ptr)22 extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSessionWithModel(JNIEnv *env, jobject thiz,
23 jobject model_buffer,
24 jlong ms_config_ptr) {
25 // decode model buffer and buffer size
26 if (model_buffer == nullptr) {
27 MS_LOGE("Buffer from java is nullptr");
28 return reinterpret_cast<jlong>(nullptr);
29 }
30 jlong buffer_len = env->GetDirectBufferCapacity(model_buffer);
31 auto *model_buf = static_cast<char *>(env->GetDirectBufferAddress(model_buffer));
32 // decode ms context
33 auto *pointer = reinterpret_cast<void *>(ms_config_ptr);
34 if (pointer == nullptr) {
35 MS_LOGE("Context pointer from java is nullptr");
36 return jlong(nullptr);
37 }
38 auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
39 // create session
40 auto session = mindspore::session::LiteSession::CreateSession(model_buf, buffer_len, lite_context_ptr);
41 if (session == nullptr) {
42 MS_LOGE("CreateSession failed");
43 return jlong(nullptr);
44 }
45 return jlong(session);
46 }
47
Java_com_mindspore_lite_LiteSession_createSession(JNIEnv * env,jobject thiz,jlong ms_config_ptr)48 extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSession(JNIEnv *env, jobject thiz,
49 jlong ms_config_ptr) {
50 auto *pointer = reinterpret_cast<void *>(ms_config_ptr);
51 if (pointer == nullptr) {
52 MS_LOGE("Context pointer from java is nullptr");
53 return jlong(nullptr);
54 }
55 auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
56 auto session = mindspore::session::LiteSession::CreateSession(lite_context_ptr);
57 if (session == nullptr) {
58 MS_LOGE("CreateSession failed");
59 return jlong(nullptr);
60 }
61 return jlong(session);
62 }
63
Java_com_mindspore_lite_LiteSession_compileGraph(JNIEnv * env,jobject thiz,jlong session_ptr,jlong model_ptr)64 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compileGraph(JNIEnv *env, jobject thiz,
65 jlong session_ptr,
66 jlong model_ptr) {
67 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
68 if (session_pointer == nullptr) {
69 MS_LOGE("Session pointer from java is nullptr");
70 return (jboolean) false;
71 }
72 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
73 auto *model_pointer = reinterpret_cast<void *>(model_ptr);
74 if (model_pointer == nullptr) {
75 MS_LOGE("Model pointer from java is nullptr");
76 return (jboolean) false;
77 }
78 auto *lite_model_ptr = static_cast<mindspore::lite::Model *>(model_pointer);
79
80 auto ret = lite_session_ptr->CompileGraph(lite_model_ptr);
81 return (jboolean)(ret == mindspore::lite::RET_OK);
82 }
83
Java_com_mindspore_lite_LiteSession_bindThread(JNIEnv * env,jobject thiz,jlong session_ptr,jboolean if_bind)84 extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread(JNIEnv *env, jobject thiz,
85 jlong session_ptr, jboolean if_bind) {
86 auto *pointer = reinterpret_cast<void *>(session_ptr);
87 if (pointer == nullptr) {
88 MS_LOGE("Session pointer from java is nullptr");
89 return;
90 }
91 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
92 lite_session_ptr->BindThread(if_bind);
93 }
94
Java_com_mindspore_lite_LiteSession_runGraph(JNIEnv * env,jobject thiz,jlong session_ptr)95 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGraph(JNIEnv *env, jobject thiz,
96 jlong session_ptr) {
97 auto *pointer = reinterpret_cast<void *>(session_ptr);
98 if (pointer == nullptr) {
99 MS_LOGE("Session pointer from java is nullptr");
100 return (jboolean) false;
101 }
102 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
103 auto ret = lite_session_ptr->RunGraph();
104 return (jboolean)(ret == mindspore::lite::RET_OK);
105 }
106
Java_com_mindspore_lite_LiteSession_getInputs(JNIEnv * env,jobject thiz,jlong session_ptr)107 extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputs(JNIEnv *env, jobject thiz,
108 jlong session_ptr) {
109 jclass array_list = env->FindClass("java/util/ArrayList");
110 jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
111 jobject ret = env->NewObject(array_list, array_list_construct);
112 jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
113
114 jclass long_object = env->FindClass("java/lang/Long");
115 jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
116 auto *pointer = reinterpret_cast<void *>(session_ptr);
117 if (pointer == nullptr) {
118 MS_LOGE("Session pointer from java is nullptr");
119 return ret;
120 }
121 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
122 auto inputs = lite_session_ptr->GetInputs();
123 for (auto input : inputs) {
124 jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
125 env->CallBooleanMethod(ret, array_list_add, tensor_addr);
126 }
127 return ret;
128 }
129
Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv * env,jobject thiz,jlong session_ptr,jstring tensor_name)130 extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env, jobject thiz,
131 jlong session_ptr,
132 jstring tensor_name) {
133 auto *pointer = reinterpret_cast<void *>(session_ptr);
134 if (pointer == nullptr) {
135 MS_LOGE("Session pointer from java is nullptr");
136 return jlong(nullptr);
137 }
138 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
139 auto input = lite_session_ptr->GetInputsByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
140 return jlong(input);
141 }
142
Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv * env,jobject thiz,jlong session_ptr,jstring node_name)143 extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,
144 jlong session_ptr,
145 jstring node_name) {
146 jclass array_list = env->FindClass("java/util/ArrayList");
147 jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
148 jobject ret = env->NewObject(array_list, array_list_construct);
149 jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
150
151 jclass long_object = env->FindClass("java/lang/Long");
152 jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
153 auto *pointer = reinterpret_cast<void *>(session_ptr);
154 if (pointer == nullptr) {
155 MS_LOGE("Session pointer from java is nullptr");
156 return ret;
157 }
158 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
159 auto inputs = lite_session_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
160 for (auto input : inputs) {
161 jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
162 env->CallBooleanMethod(ret, array_list_add, tensor_addr);
163 env->DeleteLocalRef(tensor_addr);
164 }
165 return ret;
166 }
167
Java_com_mindspore_lite_LiteSession_getOutputMapByTensor(JNIEnv * env,jobject thiz,jlong session_ptr)168 extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByTensor(JNIEnv *env, jobject thiz,
169 jlong session_ptr) {
170 jclass hash_map_clazz = env->FindClass("java/util/HashMap");
171 jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
172 jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
173 jmethodID hash_map_put =
174 env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
175 auto *pointer = reinterpret_cast<void *>(session_ptr);
176 if (pointer == nullptr) {
177 MS_LOGE("Session pointer from java is nullptr");
178 return hash_map;
179 }
180 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
181 auto outputs = lite_session_ptr->GetOutputs();
182 jclass long_object = env->FindClass("java/lang/Long");
183 jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
184 for (const auto &output_iter : outputs) {
185 auto node_name = output_iter.first;
186 auto ms_tensor = output_iter.second;
187 jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
188 env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr);
189 env->DeleteLocalRef(tensor_addr);
190 }
191 return hash_map;
192 }
193
Java_com_mindspore_lite_LiteSession_getOutputTensorNames(JNIEnv * env,jobject thiz,jlong session_ptr)194 extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputTensorNames(JNIEnv *env, jobject thiz,
195 jlong session_ptr) {
196 jclass array_list = env->FindClass("java/util/ArrayList");
197 jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
198 jobject ret = env->NewObject(array_list, array_list_construct);
199 jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
200
201 auto *pointer = reinterpret_cast<void *>(session_ptr);
202 if (pointer == nullptr) {
203 MS_LOGE("Session pointer from java is nullptr");
204 return ret;
205 }
206 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
207 auto output_names = lite_session_ptr->GetOutputTensorNames();
208 for (const auto &output_name : output_names) {
209 env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
210 }
211 return ret;
212 }
213
Java_com_mindspore_lite_LiteSession_getOutputByTensorName(JNIEnv * env,jobject thiz,jlong session_ptr,jstring tensor_name)214 extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutputByTensorName(JNIEnv *env, jobject thiz,
215 jlong session_ptr,
216 jstring tensor_name) {
217 auto *pointer = reinterpret_cast<void *>(session_ptr);
218 if (pointer == nullptr) {
219 MS_LOGE("Session pointer from java is nullptr");
220 return jlong(nullptr);
221 }
222 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
223 auto output = lite_session_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
224 return jlong(output);
225 }
226
Java_com_mindspore_lite_LiteSession_free(JNIEnv * env,jobject thiz,jlong session_ptr)227 extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_free(JNIEnv *env, jobject thiz,
228 jlong session_ptr) {
229 auto *pointer = reinterpret_cast<void *>(session_ptr);
230 if (pointer == nullptr) {
231 MS_LOGE("Session pointer from java is nullptr");
232 return;
233 }
234 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
235 delete (lite_session_ptr);
236 }
237
Java_com_mindspore_lite_LiteSession_resize(JNIEnv * env,jobject thiz,jlong session_ptr,jlongArray inputs,jobjectArray dims)238 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize(JNIEnv *env, jobject thiz,
239 jlong session_ptr, jlongArray inputs,
240 jobjectArray dims) {
241 std::vector<std::vector<int>> c_dims;
242 auto *pointer = reinterpret_cast<void *>(session_ptr);
243 if (pointer == nullptr) {
244 MS_LOGE("Session pointer from java is nullptr");
245 return false;
246 }
247 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
248
249 auto input_size = static_cast<int>(env->GetArrayLength(inputs));
250 jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
251 std::vector<mindspore::tensor::MSTensor *> c_inputs;
252 for (int i = 0; i < input_size; i++) {
253 auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
254 if (tensor_pointer == nullptr) {
255 MS_LOGE("Tensor pointer from java is nullptr");
256 return false;
257 }
258 auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
259 c_inputs.push_back(ms_tensor_ptr);
260 }
261 auto tensor_size = static_cast<int>(env->GetArrayLength(dims));
262 for (int i = 0; i < tensor_size; i++) {
263 auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
264 auto dim_size = static_cast<int>(env->GetArrayLength(array));
265 jint *dim_data = env->GetIntArrayElements(array, nullptr);
266 std::vector<int> tensor_dims(dim_size);
267 for (int j = 0; j < dim_size; j++) {
268 tensor_dims[j] = dim_data[j];
269 }
270 c_dims.push_back(tensor_dims);
271 env->ReleaseIntArrayElements(array, dim_data, JNI_ABORT);
272 env->DeleteLocalRef(array);
273 }
274 int ret = lite_session_ptr->Resize(c_inputs, c_dims);
275 return (jboolean)(ret == mindspore::lite::RET_OK);
276 }
277
278
Java_com_mindspore_lite_LiteSession_export(JNIEnv * env,jobject thiz,jlong session_ptr,jstring model_name,jint model_type,jint quantization_type)279 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_export(JNIEnv *env, jobject thiz,
280 jlong session_ptr,
281 jstring model_name,
282 jint model_type,
283 jint quantization_type) {
284 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
285 if (session_pointer == nullptr) {
286 MS_LOGE("Session pointer from java is nullptr");
287 return (jboolean) false;
288 }
289 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
290 auto ret = lite_session_ptr->Export(env->GetStringUTFChars(model_name, JNI_FALSE),
291 static_cast<mindspore::lite::ModelType>(model_type),
292 static_cast<mindspore::lite::QuantizationType>(quantization_type));
293 return (jboolean)(ret == 0);
294 }
295
Java_com_mindspore_lite_LiteSession_train(JNIEnv * env,jobject thiz,jlong session_ptr)296 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_train(JNIEnv *env, jobject thiz,
297 jlong session_ptr) {
298 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
299 if (session_pointer == nullptr) {
300 MS_LOGE("Session pointer from java is nullptr");
301 return (jboolean) false;
302 }
303 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
304 auto ret = lite_session_ptr->Train();
305 return (jboolean)(ret == mindspore::lite::RET_OK);
306 }
307
Java_com_mindspore_lite_LiteSession_eval(JNIEnv * env,jobject thiz,jlong session_ptr)308 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_eval(JNIEnv *env, jobject thiz,
309 jlong session_ptr) {
310 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
311 if (session_pointer == nullptr) {
312 MS_LOGE("Session pointer from java is nullptr");
313 return (jboolean) false;
314 }
315 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
316 auto ret = lite_session_ptr->Eval();
317 return (jboolean)(ret == mindspore::lite::RET_OK);
318 }
319
Java_com_mindspore_lite_LiteSession_isTrain(JNIEnv * env,jobject thiz,jlong session_ptr)320 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_isTrain(JNIEnv *env, jobject thiz,
321 jlong session_ptr) {
322 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
323 if (session_pointer == nullptr) {
324 MS_LOGE("Session pointer from java is nullptr");
325 return (jboolean) false;
326 }
327 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
328 auto ret = lite_session_ptr->IsTrain();
329 return (jboolean)(ret);
330 }
331
Java_com_mindspore_lite_LiteSession_isEval(JNIEnv * env,jobject thiz,jlong session_ptr)332 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_isEval(JNIEnv *env, jobject thiz,
333 jlong session_ptr) {
334 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
335 if (session_pointer == nullptr) {
336 MS_LOGE("Session pointer from java is nullptr");
337 return (jboolean) false;
338 }
339 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
340 auto ret = lite_session_ptr->IsEval();
341 return (jboolean)(ret);
342 }
343
Java_com_mindspore_lite_LiteSession_setLearningRate(JNIEnv * env,jobject thiz,jlong session_ptr,jfloat learning_rate)344 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setLearningRate(JNIEnv *env, jobject thiz,
345 jlong session_ptr,
346 jfloat learning_rate) {
347 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
348 if (session_pointer == nullptr) {
349 MS_LOGE("Session pointer from java is nullptr");
350 return (jboolean) false;
351 }
352 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
353 auto ret = lite_session_ptr->SetLearningRate(learning_rate);
354 return (jboolean)(ret == mindspore::lite::RET_OK);
355 }
356
Java_com_mindspore_lite_LiteSession_setupVirtualBatch(JNIEnv * env,jobject thiz,jlong session_ptr,jint virtual_batch_factor,jfloat learning_rate,jfloat momentum)357 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupVirtualBatch(JNIEnv *env, jobject thiz,
358 jlong session_ptr,
359 jint virtual_batch_factor,
360 jfloat learning_rate,
361 jfloat momentum) {
362 auto *session_pointer = reinterpret_cast<void *>(session_ptr);
363 if (session_pointer == nullptr) {
364 MS_LOGE("Session pointer from java is nullptr");
365 return (jboolean) false;
366 }
367 auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
368 auto ret = lite_session_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
369 return (jboolean)(ret == mindspore::lite::RET_OK);
370 }
371
Java_com_mindspore_lite_LiteSession_updateFeatures(JNIEnv * env,jclass,jlong session_ptr,jlongArray features)372 extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_updateFeatures(JNIEnv *env, jclass,
373 jlong session_ptr,
374 jlongArray features) {
375 jsize size = static_cast<int>(env->GetArrayLength(features));
376 jlong *input_data = env->GetLongArrayElements(features, nullptr);
377 std::vector<mindspore::tensor::MSTensor *> newFeatures;
378 for (int i = 0; i < size; ++i) {
379 auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
380 if (tensor_pointer == nullptr) {
381 MS_LOGE("Tensor pointer from java is nullptr");
382 return false;
383 }
384 auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
385 newFeatures.emplace_back(ms_tensor_ptr);
386 }
387 auto session = reinterpret_cast<mindspore::session::LiteSession *>(session_ptr);
388 auto ret = session->UpdateFeatureMaps(newFeatures);
389 return (jboolean)(ret == mindspore::lite::RET_OK);
390 }
391
Java_com_mindspore_lite_LiteSession_getFeaturesMap(JNIEnv * env,jobject thiz,jlong session_ptr)392 extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeaturesMap(JNIEnv *env, jobject thiz,
393 jlong session_ptr) {
394 jclass array_list = env->FindClass("java/util/ArrayList");
395 jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
396 jobject ret = env->NewObject(array_list, array_list_construct);
397 jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
398
399 jclass long_object = env->FindClass("java/lang/Long");
400 jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
401 auto *pointer = reinterpret_cast<void *>(session_ptr);
402 if (pointer == nullptr) {
403 MS_LOGE("Session pointer from java is nullptr");
404 return ret;
405 }
406 auto *train_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
407 auto inputs = train_session_ptr->GetFeatureMaps();
408 for (auto input : inputs) {
409 jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
410 env->CallBooleanMethod(ret, array_list_add, tensor_addr);
411 }
412 return ret;
413 }
414