1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ 18 19 #include <stdint.h> 20 21 #include <vector> 22 23 #include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" 24 #include "tensorflow/tools/android/test/jni/object_tracking/image.h" 25 #include "tensorflow/tools/android/test/jni/object_tracking/image_data.h" 26 #include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" 27 28 namespace tf_tracking { 29 30 struct Keypoint; 31 32 class KeypointDetector { 33 public: KeypointDetector(const KeypointDetectorConfig * const config)34 explicit KeypointDetector(const KeypointDetectorConfig* const config) 35 : config_(config), 36 keypoint_scratch_(new Image<uint8_t>(config_->image_size)), 37 interest_map_(new Image<bool>(config_->image_size)), 38 fast_quadrant_(0) { 39 interest_map_->Clear(false); 40 } 41 ~KeypointDetector()42 ~KeypointDetector() {} 43 44 // Finds a new set of keypoints for the current frame, picked from the current 45 // set of keypoints and also from a set discovered via a keypoint detector. 46 // Special attention is applied to make sure that keypoints are distributed 47 // within the supplied ROIs. 48 void FindKeypoints(const ImageData& image_data, 49 const std::vector<BoundingBox>& rois, 50 const FramePair& prev_change, 51 FramePair* const curr_change); 52 53 private: 54 // Compute the corneriness of a point in the image. 55 float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y, 56 const float x, const float y) const; 57 58 // Adds a grid of candidate keypoints to the given box, up to 59 // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower. 60 int AddExtraCandidatesForBoxes( 61 const std::vector<BoundingBox>& boxes, 62 const int max_num_keypoints, 63 Keypoint* const keypoints) const; 64 65 // Scan the frame for potential keypoints using the FAST keypoint detector. 66 // Quadrant is an argument 0-3 which refers to the quadrant of the image in 67 // which to detect keypoints. 68 int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant, 69 const int downsample_factor, 70 const int max_num_keypoints, Keypoint* const keypoints); 71 72 int FindFastKeypoints(const ImageData& image_data, 73 const int max_num_keypoints, 74 Keypoint* const keypoints); 75 76 // Score a bunch of candidate keypoints. Assigns the scores to the input 77 // candidate_keypoints array entries. 78 void ScoreKeypoints(const ImageData& image_data, 79 const int num_candidates, 80 Keypoint* const candidate_keypoints); 81 82 void SortKeypoints(const int num_candidates, 83 Keypoint* const candidate_keypoints) const; 84 85 // Selects a set of keypoints falling within the supplied box such that the 86 // most highly rated keypoints are picked first, and so that none of them are 87 // too close together. 88 int SelectKeypointsInBox( 89 const BoundingBox& box, 90 const Keypoint* const candidate_keypoints, 91 const int num_candidates, 92 const int max_keypoints, 93 const int num_existing_keypoints, 94 const Keypoint* const existing_keypoints, 95 Keypoint* const final_keypoints) const; 96 97 // Selects from the supplied sorted keypoint pool a set of keypoints that will 98 // best cover the given set of boxes, such that each box is covered at a 99 // resolution proportional to its size. 100 void SelectKeypoints( 101 const std::vector<BoundingBox>& boxes, 102 const Keypoint* const candidate_keypoints, 103 const int num_candidates, 104 FramePair* const frame_change) const; 105 106 // Copies and compacts the found keypoints in the second frame of prev_change 107 // into the array at new_keypoints. 108 static int CopyKeypoints(const FramePair& prev_change, 109 Keypoint* const new_keypoints); 110 111 const KeypointDetectorConfig* const config_; 112 113 // Scratch memory for keypoint candidacy detection and non-max suppression. 114 std::unique_ptr<Image<uint8_t> > keypoint_scratch_; 115 116 // Regions of the image to pay special attention to. 117 std::unique_ptr<Image<bool> > interest_map_; 118 119 // The current quadrant of the image to detect FAST keypoints in. 120 // Keypoint detection is staggered for performance reasons. Every four frames 121 // a full scan of the frame will have been performed. 122 int fast_quadrant_; 123 124 Keypoint tmp_keypoints_[kMaxTempKeypoints]; 125 }; 126 127 } // namespace tf_tracking 128 129 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ 130