1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ 18 19 #include <stdint.h> 20 #include <vector> 21 22 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h" 23 #include "tensorflow/examples/android/jni/object_tracking/image.h" 24 #include "tensorflow/examples/android/jni/object_tracking/image_data.h" 25 #include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" 26 27 namespace tf_tracking { 28 29 struct Keypoint; 30 31 class KeypointDetector { 32 public: KeypointDetector(const KeypointDetectorConfig * const config)33 explicit KeypointDetector(const KeypointDetectorConfig* const config) 34 : config_(config), 35 keypoint_scratch_(new Image<uint8_t>(config_->image_size)), 36 interest_map_(new Image<bool>(config_->image_size)), 37 fast_quadrant_(0) { 38 interest_map_->Clear(false); 39 } 40 ~KeypointDetector()41 ~KeypointDetector() {} 42 43 // Finds a new set of keypoints for the current frame, picked from the current 44 // set of keypoints and also from a set discovered via a keypoint detector. 45 // Special attention is applied to make sure that keypoints are distributed 46 // within the supplied ROIs. 47 void FindKeypoints(const ImageData& image_data, 48 const std::vector<BoundingBox>& rois, 49 const FramePair& prev_change, 50 FramePair* const curr_change); 51 52 private: 53 // Compute the corneriness of a point in the image. 54 float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y, 55 const float x, const float y) const; 56 57 // Adds a grid of candidate keypoints to the given box, up to 58 // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower. 59 int AddExtraCandidatesForBoxes( 60 const std::vector<BoundingBox>& boxes, 61 const int max_num_keypoints, 62 Keypoint* const keypoints) const; 63 64 // Scan the frame for potential keypoints using the FAST keypoint detector. 65 // Quadrant is an argument 0-3 which refers to the quadrant of the image in 66 // which to detect keypoints. 67 int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant, 68 const int downsample_factor, 69 const int max_num_keypoints, Keypoint* const keypoints); 70 71 int FindFastKeypoints(const ImageData& image_data, 72 const int max_num_keypoints, 73 Keypoint* const keypoints); 74 75 // Score a bunch of candidate keypoints. Assigns the scores to the input 76 // candidate_keypoints array entries. 77 void ScoreKeypoints(const ImageData& image_data, 78 const int num_candidates, 79 Keypoint* const candidate_keypoints); 80 81 void SortKeypoints(const int num_candidates, 82 Keypoint* const candidate_keypoints) const; 83 84 // Selects a set of keypoints falling within the supplied box such that the 85 // most highly rated keypoints are picked first, and so that none of them are 86 // too close together. 87 int SelectKeypointsInBox( 88 const BoundingBox& box, 89 const Keypoint* const candidate_keypoints, 90 const int num_candidates, 91 const int max_keypoints, 92 const int num_existing_keypoints, 93 const Keypoint* const existing_keypoints, 94 Keypoint* const final_keypoints) const; 95 96 // Selects from the supplied sorted keypoint pool a set of keypoints that will 97 // best cover the given set of boxes, such that each box is covered at a 98 // resolution proportional to its size. 99 void SelectKeypoints( 100 const std::vector<BoundingBox>& boxes, 101 const Keypoint* const candidate_keypoints, 102 const int num_candidates, 103 FramePair* const frame_change) const; 104 105 // Copies and compacts the found keypoints in the second frame of prev_change 106 // into the array at new_keypoints. 107 static int CopyKeypoints(const FramePair& prev_change, 108 Keypoint* const new_keypoints); 109 110 const KeypointDetectorConfig* const config_; 111 112 // Scratch memory for keypoint candidacy detection and non-max suppression. 113 std::unique_ptr<Image<uint8_t> > keypoint_scratch_; 114 115 // Regions of the image to pay special attention to. 116 std::unique_ptr<Image<bool> > interest_map_; 117 118 // The current quadrant of the image to detect FAST keypoints in. 119 // Keypoint detection is staggered for performance reasons. Every four frames 120 // a full scan of the frame will have been performed. 121 int fast_quadrant_; 122 123 Keypoint tmp_keypoints_[kMaxTempKeypoints]; 124 }; 125 126 } // namespace tf_tracking 127 128 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ 129