• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Various keypoint detecting functions.
17 
18 #include <float.h>
19 
20 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 #include "tensorflow/examples/android/jni/object_tracking/image.h"
22 #include "tensorflow/examples/android/jni/object_tracking/time_log.h"
23 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
24 
25 #include "tensorflow/examples/android/jni/object_tracking/config.h"
26 #include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
27 #include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
28 
29 namespace tf_tracking {
30 
// Returns the squared Euclidean distance between two 2-element integer
// vectors (in this file: two (U, V) chroma samples).
static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) {
  const int delta_0 = vec1[0] - vec2[0];
  const int delta_1 = vec1[1] - vec2[1];
  return delta_0 * delta_0 + delta_1 * delta_1;
}
34 
ScoreKeypoints(const ImageData & image_data,const int num_candidates,Keypoint * const candidate_keypoints)35 void KeypointDetector::ScoreKeypoints(const ImageData& image_data,
36                                       const int num_candidates,
37                                       Keypoint* const candidate_keypoints) {
38   const Image<int>& I_x = *image_data.GetSpatialX(0);
39   const Image<int>& I_y = *image_data.GetSpatialY(0);
40 
41   if (config_->detect_skin) {
42     const Image<uint8_t>& u_data = *image_data.GetU();
43     const Image<uint8_t>& v_data = *image_data.GetV();
44 
45     static const int reference[] = {111, 155};
46 
47     // Score all the keypoints.
48     for (int i = 0; i < num_candidates; ++i) {
49       Keypoint* const keypoint = candidate_keypoints + i;
50 
51       const int x_pos = keypoint->pos_.x * 2;
52       const int y_pos = keypoint->pos_.y * 2;
53 
54       const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]};
55       keypoint->score_ =
56           HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) /
57           GetDistSquaredBetween(reference, curr_color);
58     }
59   } else {
60     // Score all the keypoints.
61     for (int i = 0; i < num_candidates; ++i) {
62       Keypoint* const keypoint = candidate_keypoints + i;
63       keypoint->score_ =
64           HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y);
65     }
66   }
67 }
68 
69 
KeypointCompare(const void * const a,const void * const b)70 inline int KeypointCompare(const void* const a, const void* const b) {
71   return (reinterpret_cast<const Keypoint*>(a)->score_ -
72           reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1;
73 }
74 
75 
76 // Quicksorts detected keypoints by score.
SortKeypoints(const int num_candidates,Keypoint * const candidate_keypoints) const77 void KeypointDetector::SortKeypoints(const int num_candidates,
78                                    Keypoint* const candidate_keypoints) const {
79   qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare);
80 
81 #ifdef SANITY_CHECKS
82   // Verify that the array got sorted.
83   float last_score = FLT_MAX;
84   for (int i = 0; i < num_candidates; ++i) {
85     const float curr_score = candidate_keypoints[i].score_;
86 
87     // Scores should be monotonically increasing.
88     SCHECK(last_score >= curr_score,
89           "Quicksort failure! %d: %.5f > %d: %.5f (%d total)",
90           i - 1, last_score, i, curr_score, num_candidates);
91 
92     last_score = curr_score;
93   }
94 #endif
95 }
96 
97 
SelectKeypointsInBox(const BoundingBox & box,const Keypoint * const candidate_keypoints,const int num_candidates,const int max_keypoints,const int num_existing_keypoints,const Keypoint * const existing_keypoints,Keypoint * const final_keypoints) const98 int KeypointDetector::SelectKeypointsInBox(
99     const BoundingBox& box,
100     const Keypoint* const candidate_keypoints,
101     const int num_candidates,
102     const int max_keypoints,
103     const int num_existing_keypoints,
104     const Keypoint* const existing_keypoints,
105     Keypoint* const final_keypoints) const {
106   if (max_keypoints <= 0) {
107     return 0;
108   }
109 
110   // This is the distance within which keypoints may be placed to each other
111   // within this box, roughly based on the box dimensions.
112   const int distance =
113       MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f);
114 
115   // First, mark keypoints that already happen to be inside this region. Ignore
116   // keypoints that are outside it, however close they might be.
117   interest_map_->Clear(false);
118   for (int i = 0; i < num_existing_keypoints; ++i) {
119     const Keypoint& candidate = existing_keypoints[i];
120 
121     const int x_pos = candidate.pos_.x;
122     const int y_pos = candidate.pos_.y;
123     if (box.Contains(candidate.pos_)) {
124       MarkImage(x_pos, y_pos, distance, interest_map_.get());
125     }
126   }
127 
128   // Now, go through and check which keypoints will still fit in the box.
129   int num_keypoints_selected = 0;
130   for (int i = 0; i < num_candidates; ++i) {
131     const Keypoint& candidate = candidate_keypoints[i];
132 
133     const int x_pos = candidate.pos_.x;
134     const int y_pos = candidate.pos_.y;
135 
136     if (!box.Contains(candidate.pos_) ||
137         !interest_map_->ValidPixel(x_pos, y_pos)) {
138       continue;
139     }
140 
141     if (!(*interest_map_)[y_pos][x_pos]) {
142       final_keypoints[num_keypoints_selected++] = candidate;
143       if (num_keypoints_selected >= max_keypoints) {
144         break;
145       }
146       MarkImage(x_pos, y_pos, distance, interest_map_.get());
147     }
148   }
149   return num_keypoints_selected;
150 }
151 
152 
SelectKeypoints(const std::vector<BoundingBox> & boxes,const Keypoint * const candidate_keypoints,const int num_candidates,FramePair * const curr_change) const153 void KeypointDetector::SelectKeypoints(
154     const std::vector<BoundingBox>& boxes,
155     const Keypoint* const candidate_keypoints,
156     const int num_candidates,
157     FramePair* const curr_change) const {
158   // Now select all the interesting keypoints that fall insider our boxes.
159   curr_change->number_of_keypoints_ = 0;
160   for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
161       iter != boxes.end(); ++iter) {
162     const BoundingBox bounding_box = *iter;
163 
164     // Count up keypoints that have already been selected, and fall within our
165     // box.
166     int num_keypoints_already_in_box = 0;
167     for (int i = 0; i < curr_change->number_of_keypoints_; ++i) {
168       if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) {
169         ++num_keypoints_already_in_box;
170       }
171     }
172 
173     const int max_keypoints_to_find_in_box =
174         MIN(kMaxKeypointsForObject - num_keypoints_already_in_box,
175             kMaxKeypoints - curr_change->number_of_keypoints_);
176 
177     const int num_new_keypoints_in_box = SelectKeypointsInBox(
178         bounding_box,
179         candidate_keypoints,
180         num_candidates,
181         max_keypoints_to_find_in_box,
182         curr_change->number_of_keypoints_,
183         curr_change->frame1_keypoints_,
184         curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_);
185 
186     curr_change->number_of_keypoints_ += num_new_keypoints_in_box;
187 
188     LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_);
189   }
190 }
191 
192 
193 // Walks along the given circle checking for pixels above or below the center.
194 // Returns a score, or 0 if the keypoint did not pass the criteria.
195 //
196 // Parameters:
197 //  circle_perimeter: the circumference in pixels of the circle.
198 //  threshold: the minimum number of contiguous pixels that must be above or
199 //             below the center value.
200 //  center_ptr: the location of the center pixel in memory
201 //  offsets: the relative offsets from the center pixel of the edge pixels.
TestCircle(const int circle_perimeter,const int threshold,const uint8_t * const center_ptr,const int * offsets)202 inline int TestCircle(const int circle_perimeter, const int threshold,
203                       const uint8_t* const center_ptr, const int* offsets) {
204   // Get the actual value of the center pixel for easier reference later on.
205   const int center_value = static_cast<int>(*center_ptr);
206 
207   // Number of total pixels to check.  Have to wrap around some in case
208   // the contiguous section is split by the array edges.
209   const int num_total = circle_perimeter + threshold - 1;
210 
211   int num_above = 0;
212   int above_diff = 0;
213 
214   int num_below = 0;
215   int below_diff = 0;
216 
217   // Used to tell when this is definitely not going to meet the threshold so we
218   // can early abort.
219   int minimum_by_now = threshold - num_total + 1;
220 
221   // Go through every pixel along the perimeter of the circle, and then around
222   // again a little bit.
223   for (int i = 0; i < num_total; ++i) {
224     // This should be faster than mod.
225     const int perim_index = i < circle_perimeter ? i : i - circle_perimeter;
226 
227     // This gets the value of the current pixel along the perimeter by using
228     // a precomputed offset.
229     const int curr_value =
230         static_cast<int>(center_ptr[offsets[perim_index]]);
231 
232     const int difference = curr_value - center_value;
233 
234     if (difference > kFastDiffAmount) {
235       above_diff += difference;
236       ++num_above;
237 
238       num_below = 0;
239       below_diff = 0;
240 
241       if (num_above >= threshold) {
242         return above_diff;
243       }
244     } else if (difference < -kFastDiffAmount) {
245       below_diff += difference;
246       ++num_below;
247 
248       num_above = 0;
249       above_diff = 0;
250 
251       if (num_below >= threshold) {
252         return below_diff;
253       }
254     } else {
255       num_above = 0;
256       num_below = 0;
257       above_diff = 0;
258       below_diff = 0;
259     }
260 
261     // See if there's any chance of making the threshold.
262     if (MAX(num_above, num_below) < minimum_by_now) {
263       // Didn't pass.
264       return 0;
265     }
266     ++minimum_by_now;
267   }
268 
269   // Didn't pass.
270   return 0;
271 }
272 
273 
274 // Returns a score in the range [0.0, positive infinity) which represents the
275 // relative likelihood of a point being a corner.
HarrisFilter(const Image<int32_t> & I_x,const Image<int32_t> & I_y,const float x,const float y) const276 float KeypointDetector::HarrisFilter(const Image<int32_t>& I_x,
277                                      const Image<int32_t>& I_y, const float x,
278                                      const float y) const {
279   if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) &&
280       I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) {
281     // Image gradient matrix.
282     float G[] = { 0, 0, 0, 0 };
283     CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G);
284 
285     const float dx = G[0];
286     const float dy = G[3];
287     const float dxy = G[1];
288 
289     // Harris-Nobel corner score.
290     return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN);
291   }
292 
293   return 0.0f;
294 }
295 
296 
AddExtraCandidatesForBoxes(const std::vector<BoundingBox> & boxes,const int max_num_keypoints,Keypoint * const keypoints) const297 int KeypointDetector::AddExtraCandidatesForBoxes(
298     const std::vector<BoundingBox>& boxes,
299     const int max_num_keypoints,
300     Keypoint* const keypoints) const {
301   int num_keypoints_added = 0;
302 
303   for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
304       iter != boxes.end(); ++iter) {
305     const BoundingBox box = *iter;
306 
307     for (int i = 0; i < kNumToAddAsCandidates; ++i) {
308       for (int j = 0; j < kNumToAddAsCandidates; ++j) {
309         if (num_keypoints_added >= max_num_keypoints) {
310           LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints);
311           return num_keypoints_added;
312         }
313 
314         Keypoint& curr_keypoint = keypoints[num_keypoints_added++];
315         curr_keypoint.pos_ = Point2f(
316             box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates,
317             box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates);
318         curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST;
319       }
320     }
321   }
322 
323   return num_keypoints_added;
324 }
325 
326 
FindKeypoints(const ImageData & image_data,const std::vector<BoundingBox> & rois,const FramePair & prev_change,FramePair * const curr_change)327 void KeypointDetector::FindKeypoints(const ImageData& image_data,
328                                    const std::vector<BoundingBox>& rois,
329                                    const FramePair& prev_change,
330                                    FramePair* const curr_change) {
331   // Copy keypoints from second frame of last pass to temp keypoints of this
332   // pass.
333   int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_);
334 
335   const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints;
336   number_of_tmp_keypoints +=
337       FindFastKeypoints(image_data, max_num_fast,
338                        tmp_keypoints_ + number_of_tmp_keypoints);
339 
340   TimeLog("Found FAST keypoints");
341 
342   if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
343     LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints",
344          kMaxTempKeypoints, number_of_tmp_keypoints);
345   }
346 
347   if (kAddArbitraryKeypoints) {
348     // Add some for each object prior to scoring.
349     const int max_num_box_keypoints =
350         kMaxTempKeypoints - number_of_tmp_keypoints;
351     number_of_tmp_keypoints +=
352         AddExtraCandidatesForBoxes(rois, max_num_box_keypoints,
353                                    tmp_keypoints_ + number_of_tmp_keypoints);
354     TimeLog("Added box keypoints");
355 
356     if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
357       LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints",
358            kMaxTempKeypoints, number_of_tmp_keypoints);
359     }
360   }
361 
362   // Score them...
363   LOGV("Scoring %d keypoints!", number_of_tmp_keypoints);
364   ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_);
365   TimeLog("Scored keypoints");
366 
367   // Now pare it down a bit.
368   SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_);
369   TimeLog("Sorted keypoints");
370 
371   LOGV("%d keypoints to select from!", number_of_tmp_keypoints);
372 
373   SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change);
374   TimeLog("Selected keypoints");
375 
376   LOGV("Picked %d (%d max) final keypoints out of %d potential.",
377        curr_change->number_of_keypoints_,
378        kMaxKeypoints, number_of_tmp_keypoints);
379 }
380 
381 
CopyKeypoints(const FramePair & prev_change,Keypoint * const new_keypoints)382 int KeypointDetector::CopyKeypoints(const FramePair& prev_change,
383                                   Keypoint* const new_keypoints) {
384   int number_of_keypoints = 0;
385 
386   // Caching values from last pass, just copy and compact.
387   for (int i = 0; i < prev_change.number_of_keypoints_; ++i) {
388     if (prev_change.optical_flow_found_keypoint_[i]) {
389       new_keypoints[number_of_keypoints] =
390           prev_change.frame2_keypoints_[i];
391 
392       new_keypoints[number_of_keypoints].score_ =
393           prev_change.frame1_keypoints_[i].score_;
394 
395       ++number_of_keypoints;
396     }
397   }
398 
399   TimeLog("Copied keypoints");
400   return number_of_keypoints;
401 }
402 
403 
404 // FAST keypoint detector.
FindFastKeypoints(const Image<uint8_t> & frame,const int quadrant,const int downsample_factor,const int max_num_keypoints,Keypoint * const keypoints)405 int KeypointDetector::FindFastKeypoints(const Image<uint8_t>& frame,
406                                         const int quadrant,
407                                         const int downsample_factor,
408                                         const int max_num_keypoints,
409                                         Keypoint* const keypoints) {
410   /*
411    // Reference for a circle of diameter 7.
412    const int circle[] = {0, 0, 1, 1, 1, 0, 0,
413                          0, 1, 0, 0, 0, 1, 0,
414                          1, 0, 0, 0, 0, 0, 1,
415                          1, 0, 0, 0, 0, 0, 1,
416                          1, 0, 0, 0, 0, 0, 1,
417                          0, 1, 0, 0, 0, 1, 0,
418                          0, 0, 1, 1, 1, 0, 0};
419    const int circle_offset[] =
420        {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46};
421    */
422 
423   // Quick test of compass directions.  Any length 16 circle with a break of up
424   // to 4 pixels will have at least 3 of these 4 pixels active.
425   static const int short_circle_perimeter = 4;
426   static const int short_threshold = 3;
427   static const int short_circle_x[] = { -3,  0, +3,  0 };
428   static const int short_circle_y[] = {  0, -3,  0, +3 };
429 
430   // Precompute image offsets.
431   int short_offsets[short_circle_perimeter];
432   for (int i = 0; i < short_circle_perimeter; ++i) {
433     short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth();
434   }
435 
436   // Large circle values.
437   static const int full_circle_perimeter = 16;
438   static const int full_threshold = 12;
439   static const int full_circle_x[] =
440       { -1,  0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 };
441   static const int full_circle_y[] =
442       { -3, -3, -3, -2, -1,  0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 };
443 
444   // Precompute image offsets.
445   int full_offsets[full_circle_perimeter];
446   for (int i = 0; i < full_circle_perimeter; ++i) {
447     full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth();
448   }
449 
450   const int scratch_stride = frame.stride();
451 
452   keypoint_scratch_->Clear(0);
453 
454   // Set up the bounds on the region to test based on the passed-in quadrant.
455   const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer;
456   const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer;
457   const int start_x =
458       kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width);
459   const int start_y =
460       kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height);
461   const int end_x = start_x + quadrant_width;
462   const int end_y = start_y + quadrant_height;
463 
464   // Loop through once to find FAST keypoint clumps.
465   for (int img_y = start_y; img_y < end_y; ++img_y) {
466     const uint8_t* curr_pixel_ptr = frame[img_y] + start_x;
467 
468     for (int img_x = start_x; img_x < end_x; ++img_x) {
469       // Only insert it if it meets the quick minimum requirements test.
470       if (TestCircle(short_circle_perimeter, short_threshold,
471                      curr_pixel_ptr, short_offsets) != 0) {
472         // Longer test for actual keypoint score..
473         const int fast_score = TestCircle(full_circle_perimeter,
474                                           full_threshold,
475                                           curr_pixel_ptr,
476                                           full_offsets);
477 
478         // Non-zero score means the keypoint was found.
479         if (fast_score != 0) {
480           uint8_t* const center_ptr = (*keypoint_scratch_)[img_y] + img_x;
481 
482           // Increase the keypoint count on this pixel and the pixels in all
483           // 4 cardinal directions.
484           *center_ptr += 5;
485           *(center_ptr - 1) += 1;
486           *(center_ptr + 1) += 1;
487           *(center_ptr - scratch_stride) += 1;
488           *(center_ptr + scratch_stride) += 1;
489         }
490       }
491 
492       ++curr_pixel_ptr;
493     }  // x
494   }  // y
495 
496   TimeLog("Found FAST keypoints.");
497 
498   int num_keypoints = 0;
499   // Loop through again and Harris filter pixels in the center of clumps.
500   // We can shrink the window by 1 pixel on every side.
501   for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) {
502     const uint8_t* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x;
503 
504     for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) {
505       if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) {
506        Keypoint* const keypoint = keypoints + num_keypoints;
507         keypoint->pos_ = Point2f(
508             img_x * downsample_factor, img_y * downsample_factor);
509         keypoint->score_ = 0;
510         keypoint->type_ = KEYPOINT_TYPE_FAST;
511 
512         ++num_keypoints;
513         if (num_keypoints >= max_num_keypoints) {
514           return num_keypoints;
515         }
516       }
517 
518       ++curr_pixel_ptr;
519     }  // x
520   }  // y
521 
522   TimeLog("Picked FAST keypoints.");
523 
524   return num_keypoints;
525 }
526 
FindFastKeypoints(const ImageData & image_data,const int max_num_keypoints,Keypoint * const keypoints)527 int KeypointDetector::FindFastKeypoints(const ImageData& image_data,
528                                         const int max_num_keypoints,
529                                         Keypoint* const keypoints) {
530   int downsample_factor = 1;
531   int num_found = 0;
532 
533   // TODO(andrewharp): Get this working for multiple image scales.
534   for (int i = 0; i < 1; ++i) {
535     const Image<uint8_t>& frame = *image_data.GetPyramidSqrt2Level(i);
536     num_found += FindFastKeypoints(
537         frame, fast_quadrant_,
538         downsample_factor, max_num_keypoints, keypoints + num_found);
539     downsample_factor *= 2;
540   }
541 
542   // Increment the current quadrant.
543   fast_quadrant_ = (fast_quadrant_ + 1) % 4;
544 
545   return num_found;
546 }
547 
548 }  // namespace tf_tracking
549