• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "run_tflite.h"
18 
19 #include <jni.h>
20 #include <string>
21 #include <iomanip>
22 #include <sstream>
23 #include <fcntl.h>
24 
25 #include <android/asset_manager_jni.h>
26 #include <android/log.h>
27 #include <android/sharedmem.h>
28 #include <sys/mman.h>
29 
30 
31 extern "C"
32 JNIEXPORT jlong
33 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(JNIEnv * env,jobject,jstring _modelFileName,jboolean _useNnApi,jboolean _enableIntermediateTensorsDump,jstring _nnApiDeviceName)34 Java_com_android_nn_benchmark_core_NNTestBase_initModel(
35         JNIEnv *env,
36         jobject /* this */,
37         jstring _modelFileName,
38         jboolean _useNnApi,
39         jboolean _enableIntermediateTensorsDump,
40         jstring _nnApiDeviceName) {
41     const char *modelFileName = env->GetStringUTFChars(_modelFileName, NULL);
42     const char *nnApiDeviceName =
43         _nnApiDeviceName == NULL
44             ? NULL
45             : env->GetStringUTFChars(_nnApiDeviceName, NULL);
46     void *handle =
47         BenchmarkModel::create(modelFileName, _useNnApi,
48                                _enableIntermediateTensorsDump, nnApiDeviceName);
49     env->ReleaseStringUTFChars(_modelFileName, modelFileName);
50     if (_nnApiDeviceName != NULL) {
51         env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
52     }
53 
54     return (jlong)(uintptr_t)handle;
55 }
56 
57 extern "C"
58 JNIEXPORT void
59 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(JNIEnv * env,jobject,jlong _modelHandle)60 Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
61         JNIEnv *env,
62         jobject /* this */,
63         jlong _modelHandle) {
64     BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
65     delete(model);
66 }
67 
68 extern "C"
69 JNIEXPORT jboolean
70 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(JNIEnv * env,jobject,jlong _modelHandle,jintArray _inputShape)71 Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
72         JNIEnv *env,
73         jobject /* this */,
74         jlong _modelHandle,
75         jintArray _inputShape) {
76     BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
77     jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
78     jsize shapeLen = env->GetArrayLength(_inputShape);
79 
80     std::vector<int> shape(shapePtr, shapePtr + shapeLen);
81     return model->resizeInputTensors(std::move(shape));
82 }
83 
84 /** RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor. */
85 class InferenceInOutSequenceList {
86 public:
87     InferenceInOutSequenceList(JNIEnv *env,
88                                const jobject& inOutDataList,
89                                bool expectGoldenOutputs);
90     ~InferenceInOutSequenceList();
91 
isValid() const92     bool isValid() const { return mValid; }
93 
data() const94     const std::vector<InferenceInOutSequence>& data() const { return mData; }
95 
96 private:
97     JNIEnv *mEnv;  // not owned.
98 
99     std::vector<InferenceInOutSequence> mData;
100     std::vector<jbyteArray> mInputArrays;
101     std::vector<jobjectArray> mOutputArrays;
102     bool mValid;
103 };
104 
InferenceInOutSequenceList(JNIEnv * env,const jobject & inOutDataList,bool expectGoldenOutputs)105 InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
106                                                        const jobject& inOutDataList,
107                                                        bool expectGoldenOutputs)
108     : mEnv(env), mValid(false) {
109 
110     jclass list_class = env->FindClass("java/util/List");
111     if (list_class == nullptr) { return; }
112     jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
113     if (list_size == nullptr) { return; }
114     jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
115     if (list_get == nullptr) { return; }
116     jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
117     if (list_add == nullptr) { return; }
118 
119     jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
120     if (inOutSeq_class == nullptr) { return; }
121     jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
122     if (inOutSeq_size == nullptr) { return; }
123     jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
124                                               "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
125     if (inOutSeq_get == nullptr) { return; }
126 
127     jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
128     if (inout_class == nullptr) { return; }
129     jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
130     if (inout_input == nullptr) { return; }
131     jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
132     if (inout_expectedOutputs == nullptr) { return; }
133     jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
134             "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
135     if (inout_inputCreator == nullptr) { return; }
136 
137 
138 
139     // Fetch input/output arrays
140     size_t data_count = mEnv->CallIntMethod(inOutDataList, list_size);
141     if (env->ExceptionCheck()) { return; }
142     mData.reserve(data_count);
143 
144     jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
145     if (inputCreator_class == nullptr) { return; }
146     jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
147     if (createInput_method == nullptr) { return; }
148 
149     for (int seq_index = 0; seq_index < data_count; ++seq_index) {
150         jobject inOutSeq = mEnv->CallObjectMethod(inOutDataList, list_get, seq_index);
151         if (mEnv->ExceptionCheck()) { return; }
152 
153         size_t seqLen = mEnv->CallIntMethod(inOutSeq, inOutSeq_size);
154         if (mEnv->ExceptionCheck()) { return; }
155 
156         mData.push_back(InferenceInOutSequence{});
157         auto& seq = mData.back();
158         seq.reserve(seqLen);
159         for (int i = 0; i < seqLen; ++i) {
160             jobject inout = mEnv->CallObjectMethod(inOutSeq, inOutSeq_get, i);
161             if (mEnv->ExceptionCheck()) { return; }
162 
163             uint8_t* input_data = nullptr;
164             size_t input_len = 0;
165             std::function<bool(uint8_t*, size_t)> inputCreator;
166             jbyteArray input = static_cast<jbyteArray>(
167                     mEnv->GetObjectField(inout, inout_input));
168             mInputArrays.push_back(input);
169             if (input != nullptr) {
170                 input_data = reinterpret_cast<uint8_t*>(
171                         mEnv->GetByteArrayElements(input, NULL));
172                 input_len = mEnv->GetArrayLength(input);
173             } else {
174                 inputCreator = [env, inout, inout_inputCreator, createInput_method](
175                         uint8_t* buffer, size_t length) {
176                     jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
177                     if (byteBuffer == nullptr) { return false; }
178                     jobject creator = env->GetObjectField(inout, inout_inputCreator);
179                     if (creator == nullptr) { return false; }
180                     env->CallVoidMethod(creator, createInput_method, byteBuffer);
181                     env->DeleteLocalRef(byteBuffer);
182                     if (env->ExceptionCheck()) { return false; }
183                     return true;
184                 };
185             }
186 
187             jobjectArray expectedOutputs = static_cast<jobjectArray>(
188                     mEnv->GetObjectField(inout, inout_expectedOutputs));
189             mOutputArrays.push_back(expectedOutputs);
190             seq.push_back({input_data, input_len, {}, inputCreator});
191 
192             // Add expected output to sequence added above
193             if (expectedOutputs != nullptr) {
194                 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
195                 auto& outputs = seq.back().outputs;
196                 outputs.reserve(expectedOutputsLength);
197 
198                 for (jsize j = 0;j < expectedOutputsLength; ++j) {
199                     jbyteArray expectedOutput =
200                             static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
201                     if (env->ExceptionCheck()) {
202                         return;
203                     }
204                     if (expectedOutput == nullptr) {
205                         jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
206                         mEnv->ThrowNew(iaeClass, "Null expected output array");
207                         return;
208                     }
209 
210                     uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
211                                         mEnv->GetByteArrayElements(expectedOutput, NULL));
212                     size_t expectedOutput_len = mEnv->GetArrayLength(expectedOutput);
213                     outputs.push_back({ expectedOutput_data, expectedOutput_len});
214                 }
215             } else {
216                 if (expectGoldenOutputs) {
217                     jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
218                     mEnv->ThrowNew(iaeClass, "Expected golden output for every input");
219                     return;
220                 }
221             }
222         }
223     }
224     mValid = true;
225 }
226 
~InferenceInOutSequenceList()227 InferenceInOutSequenceList::~InferenceInOutSequenceList() {
228     // Note that we may land here with a pending JNI exception so cannot call
229     // java objects.
230     int arrayIndex = 0;
231     for (int seq_index = 0; seq_index < mData.size(); ++seq_index) {
232         for (int i = 0; i < mData[seq_index].size(); ++i) {
233             jbyteArray input = mInputArrays[arrayIndex];
234             if (input != nullptr) {
235                 mEnv->ReleaseByteArrayElements(
236                         input, reinterpret_cast<jbyte*>(mData[seq_index][i].input), JNI_ABORT);
237             }
238             jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
239             if (expectedOutputs != nullptr) {
240                 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
241                 if (expectedOutputsLength != mData[seq_index][i].outputs.size()) {
242                     // Should not happen? :)
243                     jclass iaeClass = mEnv->FindClass("java/lang/IllegalStateException");
244                     mEnv->ThrowNew(iaeClass, "Mismatch of the size of expected outputs jni array "
245                                    "and internal array of its bufers");
246                     return;
247                 }
248 
249                 for (jsize j = 0;j < expectedOutputsLength; ++j) {
250                     jbyteArray expectedOutput = static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
251                     mEnv->ReleaseByteArrayElements(
252                         expectedOutput, reinterpret_cast<jbyte*>(mData[seq_index][i].outputs[j].ptr),
253                         JNI_ABORT);
254                 }
255             }
256             arrayIndex++;
257         }
258     }
259 }
260 
261 extern "C"
262 JNIEXPORT jboolean
263 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jobject inOutDataList,jobject resultList,jint inferencesSeqMaxCount,jfloat timeoutSec,jint flags)264 Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
265         JNIEnv *env,
266         jobject /* this */,
267         jlong _modelHandle,
268         jobject inOutDataList,
269         jobject resultList,
270         jint inferencesSeqMaxCount,
271         jfloat timeoutSec,
272         jint flags) {
273 
274     BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
275 
276     jclass list_class = env->FindClass("java/util/List");
277     if (list_class == nullptr) { return false; }
278     jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
279     if (list_add == nullptr) { return false; }
280 
281     jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
282     if (result_class == nullptr) { return false; }
283     jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
284     if (result_ctor == nullptr) { return false; }
285 
286     std::vector<InferenceResult> result;
287 
288     const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
289     InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
290     if (!data.isValid()) {
291         return false;
292     }
293 
294     // TODO: Remove success boolean from this method and throw an exception in case of problems
295     bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
296 
297     // Generate results
298     if (success) {
299         for (const InferenceResult &rentry : result) {
300             jobjectArray inferenceOutputs = nullptr;
301             jfloatArray meanSquareErrorArray = nullptr;
302             jfloatArray maxSingleErrorArray = nullptr;
303 
304             if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
305                 meanSquareErrorArray = env->NewFloatArray(rentry.meanSquareErrors.size());
306                 if (env->ExceptionCheck()) { return false; }
307                 maxSingleErrorArray = env->NewFloatArray(rentry.maxSingleErrors.size());
308                 if (env->ExceptionCheck()) { return false; }
309                 {
310                     jfloat *bytes = env->GetFloatArrayElements(meanSquareErrorArray, nullptr);
311                     memcpy(bytes,
312                            &rentry.meanSquareErrors[0],
313                            rentry.meanSquareErrors.size() * sizeof(float));
314                     env->ReleaseFloatArrayElements(meanSquareErrorArray, bytes, 0);
315                 }
316                 {
317                     jfloat *bytes = env->GetFloatArrayElements(maxSingleErrorArray, nullptr);
318                     memcpy(bytes,
319                            &rentry.maxSingleErrors[0],
320                            rentry.maxSingleErrors.size() * sizeof(float));
321                     env->ReleaseFloatArrayElements(maxSingleErrorArray, bytes, 0);
322                 }
323             }
324 
325             if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
326                 jclass byteArrayClass = env->FindClass("[B");
327 
328                 inferenceOutputs = env->NewObjectArray(
329                     rentry.inferenceOutputs.size(),
330                     byteArrayClass, nullptr);
331 
332                 for (int i = 0;i < rentry.inferenceOutputs.size();++i) {
333                     jbyteArray inferenceOutput = nullptr;
334                     inferenceOutput = env->NewByteArray(rentry.inferenceOutputs[i].size());
335                     if (env->ExceptionCheck()) { return false; }
336                     jbyte *bytes = env->GetByteArrayElements(inferenceOutput, nullptr);
337                     memcpy(bytes, &rentry.inferenceOutputs[i][0], rentry.inferenceOutputs[i].size());
338                     env->ReleaseByteArrayElements(inferenceOutput, bytes, 0);
339                     env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
340                 }
341             }
342 
343             jobject object = env->NewObject(
344                 result_class, result_ctor, rentry.computeTimeSec,
345                 meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
346                 rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
347             if (env->ExceptionCheck() || object == NULL) { return false; }
348 
349             env->CallBooleanMethod(resultList, list_add, object);
350             if (env->ExceptionCheck()) { return false; }
351         }
352     }
353 
354     return success;
355 }
356 
357 extern "C"
358 JNIEXPORT void
359 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(JNIEnv * env,jobject,jlong _modelHandle,jstring dumpPath,jobject inOutDataList)360 Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
361         JNIEnv *env,
362         jobject /* this */,
363         jlong _modelHandle,
364         jstring dumpPath,
365         jobject inOutDataList) {
366 
367     BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
368 
369     InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
370     if (!data.isValid()) {
371         return;
372     }
373 
374     const char *dumpPathStr = env->GetStringUTFChars(dumpPath, JNI_FALSE);
375     model->dumpAllLayers(dumpPathStr, data.data());
376     env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
377 }
378