1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <android/log.h>
17 #include <jni.h>
18 #include <stdint.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <cstdint>
22
23 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
24 #include "tensorflow/examples/android/jni/object_tracking/image.h"
25 #include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
26 #include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27
28 #include "tensorflow/examples/android/jni/object_tracking/config.h"
29 #include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
30
31 namespace tf_tracking {
32
// Expands a bare method name into the exported JNI symbol for the matching
// native method of the Java class org.tensorflow.demo.tracking.ObjectTracker,
// e.g. OBJECT_TRACKER_METHOD(initNative) expands to
// Java_org_tensorflow_demo_tracking_ObjectTracker_initNative.
#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
  Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME  // NOLINT

// Accessor for the Java object's "nativeObjectTracker" long field, which holds
// the address of the native ObjectTracker (read by get_object_tracker, written
// by set_object_tracker).
JniLongField object_tracker_field("nativeObjectTracker");
37
// Returns the native ObjectTracker whose address is stored in the Java
// object's "nativeObjectTracker" long field.
// CHECK_ALWAYS fires when that field holds no tracker (null), e.g. after
// releaseMemoryNative has cleared it, and presumably before initNative has run.
ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
  ObjectTracker* const object_tracker =
      reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
  CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
  return object_tracker;
}
44
set_object_tracker(JNIEnv * env,jobject thiz,const ObjectTracker * object_tracker)45 void set_object_tracker(JNIEnv* env, jobject thiz,
46 const ObjectTracker* object_tracker) {
47 object_tracker_field.set(env, thiz,
48 reinterpret_cast<intptr_t>(object_tracker));
49 }
50
51 #ifdef __cplusplus
52 extern "C" {
53 #endif
54 JNIEXPORT
55 void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
56 jint width, jint height,
57 jboolean always_track);
58
59 JNIEXPORT
60 void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
61 jobject thiz);
62
63 JNIEXPORT
64 void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
65 JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
66 jfloat x2, jfloat y2, jbyteArray frame_data);
67
68 JNIEXPORT
69 void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
70 JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
71 jfloat x2, jfloat y2, jlong timestamp);
72
73 JNIEXPORT
74 void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
75 JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
76 jfloat x2, jfloat y2);
77
78 JNIEXPORT
79 jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
80 jstring object_id);
81
82 JNIEXPORT
83 jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
84 jobject thiz,
85 jstring object_id);
86
87 JNIEXPORT
88 jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
89 jobject thiz,
90 jstring object_id);
91
92 JNIEXPORT
93 jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
94 jobject thiz,
95 jstring object_id);
96
97 JNIEXPORT
98 jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
99 jstring object_id);
100
101 JNIEXPORT
102 void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
103 JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
104
105 JNIEXPORT
106 void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
107 jbyteArray y_data,
108 jbyteArray uv_data,
109 jlong timestamp,
110 jfloatArray vg_matrix_2x3);
111
112 JNIEXPORT
113 void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
114 jstring object_id);
115
116 JNIEXPORT
117 jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
118 JNIEnv* env, jobject thiz, jfloat scale_factor);
119
120 JNIEXPORT
121 jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
122 JNIEnv* env, jobject thiz, jboolean only_found_);
123
124 JNIEXPORT
125 void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
126 JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
127 jfloat position_y1, jfloat position_x2, jfloat position_y2,
128 jfloatArray delta);
129
130 JNIEXPORT
131 void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
132 jint view_width,
133 jint view_height,
134 jfloatArray delta);
135
136 JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
137 JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
138 jbyteArray input, jint factor, jbyteArray output);
139
140 #ifdef __cplusplus
141 }
142 #endif
143
144 JNIEXPORT
OBJECT_TRACKER_METHOD(initNative)145 void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
146 jint width, jint height,
147 jboolean always_track) {
148 LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
149 const Size image_size(width, height);
150 TrackerConfig* const tracker_config = new TrackerConfig(image_size);
151 tracker_config->always_track = always_track;
152
153 // XXX detector
154 ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
155 set_object_tracker(env, thiz, tracker);
156 LOGI("Initialized!");
157
158 CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
159 "Failure to set hand tracker!");
160 }
161
162 JNIEXPORT
OBJECT_TRACKER_METHOD(releaseMemoryNative)163 void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
164 jobject thiz) {
165 delete get_object_tracker(env, thiz);
166 set_object_tracker(env, thiz, NULL);
167 }
168
169 JNIEXPORT
OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)170 void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
171 JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
172 jfloat x2, jfloat y2, jbyteArray frame_data) {
173 const char* const id_str = env->GetStringUTFChars(object_id, 0);
174
175 LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
176 x2, y2);
177
178 jboolean iCopied = JNI_FALSE;
179
180 // Copy image into currFrame.
181 jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
182
183 BoundingBox bounding_box(x1, y1, x2, y2);
184 get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
185 id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box);
186
187 env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
188
189 env->ReleaseStringUTFChars(object_id, id_str);
190 }
191
192 JNIEXPORT
OBJECT_TRACKER_METHOD(setPreviousPositionNative)193 void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
194 JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
195 jfloat x2, jfloat y2, jlong timestamp) {
196 const char* const id_str = env->GetStringUTFChars(object_id, 0);
197
198 LOGI(
199 "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
200 " at time %lld",
201 id_str, x1, y1, x2, y2, static_cast<int64_t>(timestamp));
202
203 get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
204 id_str, BoundingBox(x1, y1, x2, y2), timestamp);
205
206 env->ReleaseStringUTFChars(object_id, id_str);
207 }
208
209 JNIEXPORT
OBJECT_TRACKER_METHOD(setCurrentPositionNative)210 void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
211 JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
212 jfloat x2, jfloat y2) {
213 const char* const id_str = env->GetStringUTFChars(object_id, 0);
214
215 LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
216 x2, y2);
217
218 get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
219 id_str, BoundingBox(x1, y1, x2, y2));
220
221 env->ReleaseStringUTFChars(object_id, id_str);
222 }
223
224 JNIEXPORT
OBJECT_TRACKER_METHOD(haveObject)225 jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
226 jstring object_id) {
227 const char* const id_str = env->GetStringUTFChars(object_id, 0);
228
229 const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
230 env->ReleaseStringUTFChars(object_id, id_str);
231 return haveObject;
232 }
233
234 JNIEXPORT
OBJECT_TRACKER_METHOD(isObjectVisible)235 jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
236 jobject thiz,
237 jstring object_id) {
238 const char* const id_str = env->GetStringUTFChars(object_id, 0);
239
240 const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
241 env->ReleaseStringUTFChars(object_id, id_str);
242 return visible;
243 }
244
245 JNIEXPORT
OBJECT_TRACKER_METHOD(getModelIdNative)246 jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
247 jobject thiz,
248 jstring object_id) {
249 const char* const id_str = env->GetStringUTFChars(object_id, 0);
250 const TrackedObject* const object =
251 get_object_tracker(env, thiz)->GetObject(id_str);
252 env->ReleaseStringUTFChars(object_id, id_str);
253 jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
254 return model_name;
255 }
256
257 JNIEXPORT
OBJECT_TRACKER_METHOD(getCurrentCorrelation)258 jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
259 jobject thiz,
260 jstring object_id) {
261 const char* const id_str = env->GetStringUTFChars(object_id, 0);
262
263 const float correlation =
264 get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
265 env->ReleaseStringUTFChars(object_id, id_str);
266 return correlation;
267 }
268
269 JNIEXPORT
OBJECT_TRACKER_METHOD(getMatchScore)270 jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
271 jstring object_id) {
272 const char* const id_str = env->GetStringUTFChars(object_id, 0);
273
274 const float match_score =
275 get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
276 env->ReleaseStringUTFChars(object_id, id_str);
277 return match_score;
278 }
279
280 JNIEXPORT
OBJECT_TRACKER_METHOD(getTrackedPositionNative)281 void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
282 JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
283 jboolean iCopied = JNI_FALSE;
284 const char* const id_str = env->GetStringUTFChars(object_id, 0);
285
286 const BoundingBox bounding_box =
287 get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
288 env->ReleaseStringUTFChars(object_id, id_str);
289
290 jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
291 bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
292 env->ReleaseFloatArrayElements(rect_array, rect, 0);
293 }
294
295 JNIEXPORT
OBJECT_TRACKER_METHOD(nextFrameNative)296 void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
297 jbyteArray y_data,
298 jbyteArray uv_data,
299 jlong timestamp,
300 jfloatArray vg_matrix_2x3) {
301 TimeLog("Starting object tracker");
302
303 jboolean iCopied = JNI_FALSE;
304
305 float vision_gyro_matrix_array[6];
306 jfloat* jmat = NULL;
307
308 if (vg_matrix_2x3 != NULL) {
309 // Copy the alignment matrix into a float array.
310 jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
311 for (int i = 0; i < 6; ++i) {
312 vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
313 }
314 }
315 // Copy image into currFrame.
316 jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
317 jbyte* uv_pixels =
318 uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
319
320 TimeLog("Got elements");
321
322 // Add the frame to the object tracker object.
323 get_object_tracker(env, thiz)->NextFrame(
324 reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels),
325 timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
326
327 env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
328
329 if (uv_data != NULL) {
330 env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
331 }
332
333 if (vg_matrix_2x3 != NULL) {
334 env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
335 }
336
337 TimeLog("Released elements");
338
339 PrintTimeLog();
340 ResetTimeLog();
341 }
342
343 JNIEXPORT
OBJECT_TRACKER_METHOD(forgetNative)344 void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
345 jstring object_id) {
346 const char* const id_str = env->GetStringUTFChars(object_id, 0);
347
348 get_object_tracker(env, thiz)->ForgetTarget(id_str);
349
350 env->ReleaseStringUTFChars(object_id, id_str);
351 }
352
353 JNIEXPORT
OBJECT_TRACKER_METHOD(getKeypointsNative)354 jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
355 JNIEnv* env, jobject thiz, jboolean only_found) {
356 jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
357
358 const int number_of_keypoints =
359 get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
360
361 // Create and return the array that will be passed back to Java.
362 jfloatArray keypoints =
363 env->NewFloatArray(number_of_keypoints * kKeypointStep);
364 if (keypoints == NULL) {
365 LOGE("null array!");
366 return NULL;
367 }
368 env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
369 keypoint_arr);
370
371 return keypoints;
372 }
373
374 JNIEXPORT
OBJECT_TRACKER_METHOD(getKeypointsPacked)375 jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
376 JNIEnv* env, jobject thiz, jfloat scale_factor) {
377 // 2 bytes to a uint16_t and two pairs of xy coordinates per keypoint.
378 const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2;
379 jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
380
381 const int number_of_keypoints =
382 get_object_tracker(env, thiz)->GetKeypointsPacked(
383 reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor);
384
385 // Create and return the array that will be passed back to Java.
386 jbyteArray keypoints =
387 env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
388
389 if (keypoints == NULL) {
390 LOGE("null array!");
391 return NULL;
392 }
393
394 env->SetByteArrayRegion(
395 keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
396
397 return keypoints;
398 }
399
400 JNIEXPORT
OBJECT_TRACKER_METHOD(getCurrentPositionNative)401 void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
402 JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
403 jfloat position_y1, jfloat position_x2, jfloat position_y2,
404 jfloatArray delta) {
405 jfloat point_arr[4];
406
407 const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
408 BoundingBox(position_x1, position_y1, position_x2, position_y2),
409 timestamp);
410
411 new_position.CopyToArray(point_arr);
412 env->SetFloatArrayRegion(delta, 0, 4, point_arr);
413 }
414
415 JNIEXPORT
OBJECT_TRACKER_METHOD(drawNative)416 void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
417 JNIEnv* env, jobject thiz, jint view_width, jint view_height,
418 jfloatArray frame_to_canvas_arr) {
419 ObjectTracker* object_tracker = get_object_tracker(env, thiz);
420 if (object_tracker != NULL) {
421 jfloat* frame_to_canvas =
422 env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
423
424 object_tracker->Draw(view_width, view_height, frame_to_canvas);
425 env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
426 JNI_ABORT);
427 }
428 }
429
OBJECT_TRACKER_METHOD(downsampleImageNative)430 JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
431 JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
432 jbyteArray input, jint factor, jbyteArray output) {
433 if (input == NULL || output == NULL) {
434 LOGW("Received null arrays, hopefully this is a test!");
435 return;
436 }
437
438 jbyte* const input_array = env->GetByteArrayElements(input, 0);
439 jbyte* const output_array = env->GetByteArrayElements(output, 0);
440
441 {
442 tf_tracking::Image<uint8_t> full_image(
443 width, height, reinterpret_cast<uint8_t*>(input_array), false);
444
445 const int new_width = (width + factor - 1) / factor;
446 const int new_height = (height + factor - 1) / factor;
447
448 tf_tracking::Image<uint8_t> downsampled_image(
449 new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false);
450
451 downsampled_image.DownsampleAveraged(
452 reinterpret_cast<uint8_t*>(input_array), row_stride, factor);
453 }
454
455 env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
456 env->ReleaseByteArrayElements(output, output_array, 0);
457 }
458
459 } // namespace tf_tracking
460