1 /**
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "run_tflite.h"
18
19 #include <jni.h>
20 #include <string>
21 #include <iomanip>
22 #include <sstream>
23 #include <fcntl.h>
24
25 #include <android/asset_manager_jni.h>
26 #include <android/log.h>
27 #include <android/sharedmem.h>
28 #include <sys/mman.h>
29
30
31 extern "C"
32 JNIEXPORT jlong
33 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(JNIEnv * env,jobject,jstring _modelFileName,jboolean _useNnApi,jboolean _enableIntermediateTensorsDump,jstring _nnApiDeviceName)34 Java_com_android_nn_benchmark_core_NNTestBase_initModel(
35 JNIEnv *env,
36 jobject /* this */,
37 jstring _modelFileName,
38 jboolean _useNnApi,
39 jboolean _enableIntermediateTensorsDump,
40 jstring _nnApiDeviceName) {
41 const char *modelFileName = env->GetStringUTFChars(_modelFileName, NULL);
42 const char *nnApiDeviceName =
43 _nnApiDeviceName == NULL
44 ? NULL
45 : env->GetStringUTFChars(_nnApiDeviceName, NULL);
46 void *handle =
47 BenchmarkModel::create(modelFileName, _useNnApi,
48 _enableIntermediateTensorsDump, nnApiDeviceName);
49 env->ReleaseStringUTFChars(_modelFileName, modelFileName);
50 if (_nnApiDeviceName != NULL) {
51 env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
52 }
53
54 return (jlong)(uintptr_t)handle;
55 }
56
57 extern "C"
58 JNIEXPORT void
59 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(JNIEnv * env,jobject,jlong _modelHandle)60 Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
61 JNIEnv *env,
62 jobject /* this */,
63 jlong _modelHandle) {
64 BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
65 delete(model);
66 }
67
68 extern "C"
69 JNIEXPORT jboolean
70 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(JNIEnv * env,jobject,jlong _modelHandle,jintArray _inputShape)71 Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
72 JNIEnv *env,
73 jobject /* this */,
74 jlong _modelHandle,
75 jintArray _inputShape) {
76 BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
77 jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
78 jsize shapeLen = env->GetArrayLength(_inputShape);
79
80 std::vector<int> shape(shapePtr, shapePtr + shapeLen);
81 return model->resizeInputTensors(std::move(shape));
82 }
83
84 /** RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor. */
85 class InferenceInOutSequenceList {
86 public:
87 InferenceInOutSequenceList(JNIEnv *env,
88 const jobject& inOutDataList,
89 bool expectGoldenOutputs);
90 ~InferenceInOutSequenceList();
91
isValid() const92 bool isValid() const { return mValid; }
93
data() const94 const std::vector<InferenceInOutSequence>& data() const { return mData; }
95
96 private:
97 JNIEnv *mEnv; // not owned.
98
99 std::vector<InferenceInOutSequence> mData;
100 std::vector<jbyteArray> mInputArrays;
101 std::vector<jobjectArray> mOutputArrays;
102 bool mValid;
103 };
104
InferenceInOutSequenceList(JNIEnv * env,const jobject & inOutDataList,bool expectGoldenOutputs)105 InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
106 const jobject& inOutDataList,
107 bool expectGoldenOutputs)
108 : mEnv(env), mValid(false) {
109
110 jclass list_class = env->FindClass("java/util/List");
111 if (list_class == nullptr) { return; }
112 jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
113 if (list_size == nullptr) { return; }
114 jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
115 if (list_get == nullptr) { return; }
116 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
117 if (list_add == nullptr) { return; }
118
119 jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
120 if (inOutSeq_class == nullptr) { return; }
121 jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
122 if (inOutSeq_size == nullptr) { return; }
123 jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
124 "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
125 if (inOutSeq_get == nullptr) { return; }
126
127 jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
128 if (inout_class == nullptr) { return; }
129 jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
130 if (inout_input == nullptr) { return; }
131 jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
132 if (inout_expectedOutputs == nullptr) { return; }
133 jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
134 "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
135 if (inout_inputCreator == nullptr) { return; }
136
137
138
139 // Fetch input/output arrays
140 size_t data_count = mEnv->CallIntMethod(inOutDataList, list_size);
141 if (env->ExceptionCheck()) { return; }
142 mData.reserve(data_count);
143
144 jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
145 if (inputCreator_class == nullptr) { return; }
146 jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
147 if (createInput_method == nullptr) { return; }
148
149 for (int seq_index = 0; seq_index < data_count; ++seq_index) {
150 jobject inOutSeq = mEnv->CallObjectMethod(inOutDataList, list_get, seq_index);
151 if (mEnv->ExceptionCheck()) { return; }
152
153 size_t seqLen = mEnv->CallIntMethod(inOutSeq, inOutSeq_size);
154 if (mEnv->ExceptionCheck()) { return; }
155
156 mData.push_back(InferenceInOutSequence{});
157 auto& seq = mData.back();
158 seq.reserve(seqLen);
159 for (int i = 0; i < seqLen; ++i) {
160 jobject inout = mEnv->CallObjectMethod(inOutSeq, inOutSeq_get, i);
161 if (mEnv->ExceptionCheck()) { return; }
162
163 uint8_t* input_data = nullptr;
164 size_t input_len = 0;
165 std::function<bool(uint8_t*, size_t)> inputCreator;
166 jbyteArray input = static_cast<jbyteArray>(
167 mEnv->GetObjectField(inout, inout_input));
168 mInputArrays.push_back(input);
169 if (input != nullptr) {
170 input_data = reinterpret_cast<uint8_t*>(
171 mEnv->GetByteArrayElements(input, NULL));
172 input_len = mEnv->GetArrayLength(input);
173 } else {
174 inputCreator = [env, inout, inout_inputCreator, createInput_method](
175 uint8_t* buffer, size_t length) {
176 jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
177 if (byteBuffer == nullptr) { return false; }
178 jobject creator = env->GetObjectField(inout, inout_inputCreator);
179 if (creator == nullptr) { return false; }
180 env->CallVoidMethod(creator, createInput_method, byteBuffer);
181 env->DeleteLocalRef(byteBuffer);
182 if (env->ExceptionCheck()) { return false; }
183 return true;
184 };
185 }
186
187 jobjectArray expectedOutputs = static_cast<jobjectArray>(
188 mEnv->GetObjectField(inout, inout_expectedOutputs));
189 mOutputArrays.push_back(expectedOutputs);
190 seq.push_back({input_data, input_len, {}, inputCreator});
191
192 // Add expected output to sequence added above
193 if (expectedOutputs != nullptr) {
194 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
195 auto& outputs = seq.back().outputs;
196 outputs.reserve(expectedOutputsLength);
197
198 for (jsize j = 0;j < expectedOutputsLength; ++j) {
199 jbyteArray expectedOutput =
200 static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
201 if (env->ExceptionCheck()) {
202 return;
203 }
204 if (expectedOutput == nullptr) {
205 jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
206 mEnv->ThrowNew(iaeClass, "Null expected output array");
207 return;
208 }
209
210 uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
211 mEnv->GetByteArrayElements(expectedOutput, NULL));
212 size_t expectedOutput_len = mEnv->GetArrayLength(expectedOutput);
213 outputs.push_back({ expectedOutput_data, expectedOutput_len});
214 }
215 } else {
216 if (expectGoldenOutputs) {
217 jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
218 mEnv->ThrowNew(iaeClass, "Expected golden output for every input");
219 return;
220 }
221 }
222 }
223 }
224 mValid = true;
225 }
226
~InferenceInOutSequenceList()227 InferenceInOutSequenceList::~InferenceInOutSequenceList() {
228 // Note that we may land here with a pending JNI exception so cannot call
229 // java objects.
230 int arrayIndex = 0;
231 for (int seq_index = 0; seq_index < mData.size(); ++seq_index) {
232 for (int i = 0; i < mData[seq_index].size(); ++i) {
233 jbyteArray input = mInputArrays[arrayIndex];
234 if (input != nullptr) {
235 mEnv->ReleaseByteArrayElements(
236 input, reinterpret_cast<jbyte*>(mData[seq_index][i].input), JNI_ABORT);
237 }
238 jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
239 if (expectedOutputs != nullptr) {
240 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
241 if (expectedOutputsLength != mData[seq_index][i].outputs.size()) {
242 // Should not happen? :)
243 jclass iaeClass = mEnv->FindClass("java/lang/IllegalStateException");
244 mEnv->ThrowNew(iaeClass, "Mismatch of the size of expected outputs jni array "
245 "and internal array of its bufers");
246 return;
247 }
248
249 for (jsize j = 0;j < expectedOutputsLength; ++j) {
250 jbyteArray expectedOutput = static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
251 mEnv->ReleaseByteArrayElements(
252 expectedOutput, reinterpret_cast<jbyte*>(mData[seq_index][i].outputs[j].ptr),
253 JNI_ABORT);
254 }
255 }
256 arrayIndex++;
257 }
258 }
259 }
260
261 extern "C"
262 JNIEXPORT jboolean
263 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jobject inOutDataList,jobject resultList,jint inferencesSeqMaxCount,jfloat timeoutSec,jint flags)264 Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
265 JNIEnv *env,
266 jobject /* this */,
267 jlong _modelHandle,
268 jobject inOutDataList,
269 jobject resultList,
270 jint inferencesSeqMaxCount,
271 jfloat timeoutSec,
272 jint flags) {
273
274 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
275
276 jclass list_class = env->FindClass("java/util/List");
277 if (list_class == nullptr) { return false; }
278 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
279 if (list_add == nullptr) { return false; }
280
281 jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
282 if (result_class == nullptr) { return false; }
283 jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
284 if (result_ctor == nullptr) { return false; }
285
286 std::vector<InferenceResult> result;
287
288 const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
289 InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
290 if (!data.isValid()) {
291 return false;
292 }
293
294 // TODO: Remove success boolean from this method and throw an exception in case of problems
295 bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
296
297 // Generate results
298 if (success) {
299 for (const InferenceResult &rentry : result) {
300 jobjectArray inferenceOutputs = nullptr;
301 jfloatArray meanSquareErrorArray = nullptr;
302 jfloatArray maxSingleErrorArray = nullptr;
303
304 if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
305 meanSquareErrorArray = env->NewFloatArray(rentry.meanSquareErrors.size());
306 if (env->ExceptionCheck()) { return false; }
307 maxSingleErrorArray = env->NewFloatArray(rentry.maxSingleErrors.size());
308 if (env->ExceptionCheck()) { return false; }
309 {
310 jfloat *bytes = env->GetFloatArrayElements(meanSquareErrorArray, nullptr);
311 memcpy(bytes,
312 &rentry.meanSquareErrors[0],
313 rentry.meanSquareErrors.size() * sizeof(float));
314 env->ReleaseFloatArrayElements(meanSquareErrorArray, bytes, 0);
315 }
316 {
317 jfloat *bytes = env->GetFloatArrayElements(maxSingleErrorArray, nullptr);
318 memcpy(bytes,
319 &rentry.maxSingleErrors[0],
320 rentry.maxSingleErrors.size() * sizeof(float));
321 env->ReleaseFloatArrayElements(maxSingleErrorArray, bytes, 0);
322 }
323 }
324
325 if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
326 jclass byteArrayClass = env->FindClass("[B");
327
328 inferenceOutputs = env->NewObjectArray(
329 rentry.inferenceOutputs.size(),
330 byteArrayClass, nullptr);
331
332 for (int i = 0;i < rentry.inferenceOutputs.size();++i) {
333 jbyteArray inferenceOutput = nullptr;
334 inferenceOutput = env->NewByteArray(rentry.inferenceOutputs[i].size());
335 if (env->ExceptionCheck()) { return false; }
336 jbyte *bytes = env->GetByteArrayElements(inferenceOutput, nullptr);
337 memcpy(bytes, &rentry.inferenceOutputs[i][0], rentry.inferenceOutputs[i].size());
338 env->ReleaseByteArrayElements(inferenceOutput, bytes, 0);
339 env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
340 }
341 }
342
343 jobject object = env->NewObject(
344 result_class, result_ctor, rentry.computeTimeSec,
345 meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
346 rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
347 if (env->ExceptionCheck() || object == NULL) { return false; }
348
349 env->CallBooleanMethod(resultList, list_add, object);
350 if (env->ExceptionCheck()) { return false; }
351 }
352 }
353
354 return success;
355 }
356
357 extern "C"
358 JNIEXPORT void
359 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(JNIEnv * env,jobject,jlong _modelHandle,jstring dumpPath,jobject inOutDataList)360 Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
361 JNIEnv *env,
362 jobject /* this */,
363 jlong _modelHandle,
364 jstring dumpPath,
365 jobject inOutDataList) {
366
367 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
368
369 InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
370 if (!data.isValid()) {
371 return;
372 }
373
374 const char *dumpPathStr = env->GetStringUTFChars(dumpPath, JNI_FALSE);
375 model->dumpAllLayers(dumpPathStr, data.data());
376 env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
377 }
378