• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <android/log.h>
17 #include <jni.h>
18 #include <stdint.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <cstdint>
22 
23 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
24 #include "tensorflow/examples/android/jni/object_tracking/image.h"
25 #include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
26 #include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27 
28 #include "tensorflow/examples/android/jni/object_tracking/config.h"
29 #include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
30 
31 namespace tf_tracking {
32 
33 #define OBJECT_TRACKER_METHOD(METHOD_NAME) \
34   Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME  // NOLINT
35 
36 JniLongField object_tracker_field("nativeObjectTracker");
37 
get_object_tracker(JNIEnv * env,jobject thiz)38 ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
39   ObjectTracker* const object_tracker =
40       reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
41   CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
42   return object_tracker;
43 }
44 
set_object_tracker(JNIEnv * env,jobject thiz,const ObjectTracker * object_tracker)45 void set_object_tracker(JNIEnv* env, jobject thiz,
46                         const ObjectTracker* object_tracker) {
47   object_tracker_field.set(env, thiz,
48                            reinterpret_cast<intptr_t>(object_tracker));
49 }
50 
51 #ifdef __cplusplus
52 extern "C" {
53 #endif
54 JNIEXPORT
55 void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
56                                                jint width, jint height,
57                                                jboolean always_track);
58 
59 JNIEXPORT
60 void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
61                                                         jobject thiz);
62 
63 JNIEXPORT
64 void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
65     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
66     jfloat x2, jfloat y2, jbyteArray frame_data);
67 
68 JNIEXPORT
69 void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
70     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
71     jfloat x2, jfloat y2, jlong timestamp);
72 
73 JNIEXPORT
74 void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
75     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
76     jfloat x2, jfloat y2);
77 
78 JNIEXPORT
79 jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
80                                                    jstring object_id);
81 
82 JNIEXPORT
83 jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
84                                                         jobject thiz,
85                                                         jstring object_id);
86 
87 JNIEXPORT
88 jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
89                                                         jobject thiz,
90                                                         jstring object_id);
91 
92 JNIEXPORT
93 jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
94                                                             jobject thiz,
95                                                             jstring object_id);
96 
97 JNIEXPORT
98 jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
99                                                     jstring object_id);
100 
101 JNIEXPORT
102 void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
103     JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
104 
105 JNIEXPORT
106 void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
107                                                     jbyteArray y_data,
108                                                     jbyteArray uv_data,
109                                                     jlong timestamp,
110                                                     jfloatArray vg_matrix_2x3);
111 
112 JNIEXPORT
113 void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
114                                                  jstring object_id);
115 
116 JNIEXPORT
117 jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
118     JNIEnv* env, jobject thiz, jfloat scale_factor);
119 
120 JNIEXPORT
121 jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
122     JNIEnv* env, jobject thiz, jboolean only_found_);
123 
124 JNIEXPORT
125 void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
126     JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
127     jfloat position_y1, jfloat position_x2, jfloat position_y2,
128     jfloatArray delta);
129 
130 JNIEXPORT
131 void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
132                                                jint view_width,
133                                                jint view_height,
134                                                jfloatArray delta);
135 
136 JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
137     JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
138     jbyteArray input, jint factor, jbyteArray output);
139 
140 #ifdef __cplusplus
141 }
142 #endif
143 
144 JNIEXPORT
OBJECT_TRACKER_METHOD(initNative)145 void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
146                                                jint width, jint height,
147                                                jboolean always_track) {
148   LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
149   const Size image_size(width, height);
150   TrackerConfig* const tracker_config = new TrackerConfig(image_size);
151   tracker_config->always_track = always_track;
152 
153   // XXX detector
154   ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
155   set_object_tracker(env, thiz, tracker);
156   LOGI("Initialized!");
157 
158   CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
159                "Failure to set hand tracker!");
160 }
161 
162 JNIEXPORT
OBJECT_TRACKER_METHOD(releaseMemoryNative)163 void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
164                                                         jobject thiz) {
165   delete get_object_tracker(env, thiz);
166   set_object_tracker(env, thiz, NULL);
167 }
168 
169 JNIEXPORT
OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)170 void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
171     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
172     jfloat x2, jfloat y2, jbyteArray frame_data) {
173   const char* const id_str = env->GetStringUTFChars(object_id, 0);
174 
175   LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
176        x2, y2);
177 
178   jboolean iCopied = JNI_FALSE;
179 
180   // Copy image into currFrame.
181   jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
182 
183   BoundingBox bounding_box(x1, y1, x2, y2);
184   get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
185       id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box);
186 
187   env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
188 
189   env->ReleaseStringUTFChars(object_id, id_str);
190 }
191 
192 JNIEXPORT
OBJECT_TRACKER_METHOD(setPreviousPositionNative)193 void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
194     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
195     jfloat x2, jfloat y2, jlong timestamp) {
196   const char* const id_str = env->GetStringUTFChars(object_id, 0);
197 
198   LOGI(
199       "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
200       " at time %lld",
201       id_str, x1, y1, x2, y2, static_cast<int64_t>(timestamp));
202 
203   get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
204       id_str, BoundingBox(x1, y1, x2, y2), timestamp);
205 
206   env->ReleaseStringUTFChars(object_id, id_str);
207 }
208 
209 JNIEXPORT
OBJECT_TRACKER_METHOD(setCurrentPositionNative)210 void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
211     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
212     jfloat x2, jfloat y2) {
213   const char* const id_str = env->GetStringUTFChars(object_id, 0);
214 
215   LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
216        x2, y2);
217 
218   get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
219       id_str, BoundingBox(x1, y1, x2, y2));
220 
221   env->ReleaseStringUTFChars(object_id, id_str);
222 }
223 
224 JNIEXPORT
OBJECT_TRACKER_METHOD(haveObject)225 jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
226                                                    jstring object_id) {
227   const char* const id_str = env->GetStringUTFChars(object_id, 0);
228 
229   const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
230   env->ReleaseStringUTFChars(object_id, id_str);
231   return haveObject;
232 }
233 
234 JNIEXPORT
OBJECT_TRACKER_METHOD(isObjectVisible)235 jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
236                                                         jobject thiz,
237                                                         jstring object_id) {
238   const char* const id_str = env->GetStringUTFChars(object_id, 0);
239 
240   const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
241   env->ReleaseStringUTFChars(object_id, id_str);
242   return visible;
243 }
244 
245 JNIEXPORT
OBJECT_TRACKER_METHOD(getModelIdNative)246 jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
247                                                         jobject thiz,
248                                                         jstring object_id) {
249   const char* const id_str = env->GetStringUTFChars(object_id, 0);
250   const TrackedObject* const object =
251       get_object_tracker(env, thiz)->GetObject(id_str);
252   env->ReleaseStringUTFChars(object_id, id_str);
253   jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
254   return model_name;
255 }
256 
257 JNIEXPORT
OBJECT_TRACKER_METHOD(getCurrentCorrelation)258 jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
259                                                             jobject thiz,
260                                                             jstring object_id) {
261   const char* const id_str = env->GetStringUTFChars(object_id, 0);
262 
263   const float correlation =
264       get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
265   env->ReleaseStringUTFChars(object_id, id_str);
266   return correlation;
267 }
268 
269 JNIEXPORT
OBJECT_TRACKER_METHOD(getMatchScore)270 jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
271                                                     jstring object_id) {
272   const char* const id_str = env->GetStringUTFChars(object_id, 0);
273 
274   const float match_score =
275       get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
276   env->ReleaseStringUTFChars(object_id, id_str);
277   return match_score;
278 }
279 
280 JNIEXPORT
OBJECT_TRACKER_METHOD(getTrackedPositionNative)281 void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
282     JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
283   jboolean iCopied = JNI_FALSE;
284   const char* const id_str = env->GetStringUTFChars(object_id, 0);
285 
286   const BoundingBox bounding_box =
287       get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
288   env->ReleaseStringUTFChars(object_id, id_str);
289 
290   jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
291   bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
292   env->ReleaseFloatArrayElements(rect_array, rect, 0);
293 }
294 
295 JNIEXPORT
OBJECT_TRACKER_METHOD(nextFrameNative)296 void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
297                                                     jbyteArray y_data,
298                                                     jbyteArray uv_data,
299                                                     jlong timestamp,
300                                                     jfloatArray vg_matrix_2x3) {
301   TimeLog("Starting object tracker");
302 
303   jboolean iCopied = JNI_FALSE;
304 
305   float vision_gyro_matrix_array[6];
306   jfloat* jmat = NULL;
307 
308   if (vg_matrix_2x3 != NULL) {
309     // Copy the alignment matrix into a float array.
310     jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
311     for (int i = 0; i < 6; ++i) {
312       vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
313     }
314   }
315   // Copy image into currFrame.
316   jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
317   jbyte* uv_pixels =
318       uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
319 
320   TimeLog("Got elements");
321 
322   // Add the frame to the object tracker object.
323   get_object_tracker(env, thiz)->NextFrame(
324       reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels),
325       timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
326 
327   env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
328 
329   if (uv_data != NULL) {
330     env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
331   }
332 
333   if (vg_matrix_2x3 != NULL) {
334     env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
335   }
336 
337   TimeLog("Released elements");
338 
339   PrintTimeLog();
340   ResetTimeLog();
341 }
342 
343 JNIEXPORT
OBJECT_TRACKER_METHOD(forgetNative)344 void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
345                                                  jstring object_id) {
346   const char* const id_str = env->GetStringUTFChars(object_id, 0);
347 
348   get_object_tracker(env, thiz)->ForgetTarget(id_str);
349 
350   env->ReleaseStringUTFChars(object_id, id_str);
351 }
352 
353 JNIEXPORT
OBJECT_TRACKER_METHOD(getKeypointsNative)354 jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
355     JNIEnv* env, jobject thiz, jboolean only_found) {
356   jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
357 
358   const int number_of_keypoints =
359       get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
360 
361   // Create and return the array that will be passed back to Java.
362   jfloatArray keypoints =
363       env->NewFloatArray(number_of_keypoints * kKeypointStep);
364   if (keypoints == NULL) {
365     LOGE("null array!");
366     return NULL;
367   }
368   env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
369                            keypoint_arr);
370 
371   return keypoints;
372 }
373 
374 JNIEXPORT
OBJECT_TRACKER_METHOD(getKeypointsPacked)375 jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
376     JNIEnv* env, jobject thiz, jfloat scale_factor) {
377   // 2 bytes to a uint16_t and two pairs of xy coordinates per keypoint.
378   const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2;
379   jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
380 
381   const int number_of_keypoints =
382       get_object_tracker(env, thiz)->GetKeypointsPacked(
383           reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor);
384 
385   // Create and return the array that will be passed back to Java.
386   jbyteArray keypoints =
387       env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
388 
389   if (keypoints == NULL) {
390     LOGE("null array!");
391     return NULL;
392   }
393 
394   env->SetByteArrayRegion(
395       keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
396 
397   return keypoints;
398 }
399 
400 JNIEXPORT
OBJECT_TRACKER_METHOD(getCurrentPositionNative)401 void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
402     JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
403     jfloat position_y1, jfloat position_x2, jfloat position_y2,
404     jfloatArray delta) {
405   jfloat point_arr[4];
406 
407   const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
408       BoundingBox(position_x1, position_y1, position_x2, position_y2),
409       timestamp);
410 
411   new_position.CopyToArray(point_arr);
412   env->SetFloatArrayRegion(delta, 0, 4, point_arr);
413 }
414 
415 JNIEXPORT
OBJECT_TRACKER_METHOD(drawNative)416 void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
417     JNIEnv* env, jobject thiz, jint view_width, jint view_height,
418     jfloatArray frame_to_canvas_arr) {
419   ObjectTracker* object_tracker = get_object_tracker(env, thiz);
420   if (object_tracker != NULL) {
421     jfloat* frame_to_canvas =
422         env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
423 
424     object_tracker->Draw(view_width, view_height, frame_to_canvas);
425     env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
426                                    JNI_ABORT);
427   }
428 }
429 
OBJECT_TRACKER_METHOD(downsampleImageNative)430 JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
431     JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
432     jbyteArray input, jint factor, jbyteArray output) {
433   if (input == NULL || output == NULL) {
434     LOGW("Received null arrays, hopefully this is a test!");
435     return;
436   }
437 
438   jbyte* const input_array = env->GetByteArrayElements(input, 0);
439   jbyte* const output_array = env->GetByteArrayElements(output, 0);
440 
441   {
442     tf_tracking::Image<uint8_t> full_image(
443         width, height, reinterpret_cast<uint8_t*>(input_array), false);
444 
445     const int new_width = (width + factor - 1) / factor;
446     const int new_height = (height + factor - 1) / factor;
447 
448     tf_tracking::Image<uint8_t> downsampled_image(
449         new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false);
450 
451     downsampled_image.DownsampleAveraged(
452         reinterpret_cast<uint8_t*>(input_array), row_stride, factor);
453   }
454 
455   env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
456   env->ReleaseByteArrayElements(output, output_array, 0);
457 }
458 
459 }  // namespace tf_tracking
460