• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
18 
#include <stdint.h>

#include <memory>
#include <vector>

#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image_data.h"
#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h"
27 
28 namespace tf_tracking {
29 
30 struct Keypoint;
31 
32 class KeypointDetector {
33  public:
KeypointDetector(const KeypointDetectorConfig * const config)34   explicit KeypointDetector(const KeypointDetectorConfig* const config)
35       : config_(config),
36         keypoint_scratch_(new Image<uint8_t>(config_->image_size)),
37         interest_map_(new Image<bool>(config_->image_size)),
38         fast_quadrant_(0) {
39     interest_map_->Clear(false);
40   }
41 
~KeypointDetector()42   ~KeypointDetector() {}
43 
44   // Finds a new set of keypoints for the current frame, picked from the current
45   // set of keypoints and also from a set discovered via a keypoint detector.
46   // Special attention is applied to make sure that keypoints are distributed
47   // within the supplied ROIs.
48   void FindKeypoints(const ImageData& image_data,
49                      const std::vector<BoundingBox>& rois,
50                      const FramePair& prev_change,
51                      FramePair* const curr_change);
52 
53  private:
54   // Compute the corneriness of a point in the image.
55   float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y,
56                      const float x, const float y) const;
57 
58   // Adds a grid of candidate keypoints to the given box, up to
59   // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower.
60   int AddExtraCandidatesForBoxes(
61       const std::vector<BoundingBox>& boxes,
62       const int max_num_keypoints,
63       Keypoint* const keypoints) const;
64 
65   // Scan the frame for potential keypoints using the FAST keypoint detector.
66   // Quadrant is an argument 0-3 which refers to the quadrant of the image in
67   // which to detect keypoints.
68   int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant,
69                         const int downsample_factor,
70                         const int max_num_keypoints, Keypoint* const keypoints);
71 
72   int FindFastKeypoints(const ImageData& image_data,
73                         const int max_num_keypoints,
74                         Keypoint* const keypoints);
75 
76   // Score a bunch of candidate keypoints.  Assigns the scores to the input
77   // candidate_keypoints array entries.
78   void ScoreKeypoints(const ImageData& image_data,
79                       const int num_candidates,
80                       Keypoint* const candidate_keypoints);
81 
82   void SortKeypoints(const int num_candidates,
83                     Keypoint* const candidate_keypoints) const;
84 
85   // Selects a set of keypoints falling within the supplied box such that the
86   // most highly rated keypoints are picked first, and so that none of them are
87   // too close together.
88   int SelectKeypointsInBox(
89       const BoundingBox& box,
90       const Keypoint* const candidate_keypoints,
91       const int num_candidates,
92       const int max_keypoints,
93       const int num_existing_keypoints,
94       const Keypoint* const existing_keypoints,
95       Keypoint* const final_keypoints) const;
96 
97   // Selects from the supplied sorted keypoint pool a set of keypoints that will
98   // best cover the given set of boxes, such that each box is covered at a
99   // resolution proportional to its size.
100   void SelectKeypoints(
101       const std::vector<BoundingBox>& boxes,
102       const Keypoint* const candidate_keypoints,
103       const int num_candidates,
104       FramePair* const frame_change) const;
105 
106   // Copies and compacts the found keypoints in the second frame of prev_change
107   // into the array at new_keypoints.
108   static int CopyKeypoints(const FramePair& prev_change,
109                           Keypoint* const new_keypoints);
110 
111   const KeypointDetectorConfig* const config_;
112 
113   // Scratch memory for keypoint candidacy detection and non-max suppression.
114   std::unique_ptr<Image<uint8_t> > keypoint_scratch_;
115 
116   // Regions of the image to pay special attention to.
117   std::unique_ptr<Image<bool> > interest_map_;
118 
119   // The current quadrant of the image to detect FAST keypoints in.
120   // Keypoint detection is staggered for performance reasons. Every four frames
121   // a full scan of the frame will have been performed.
122   int fast_quadrant_;
123 
124   Keypoint tmp_keypoints_[kMaxTempKeypoints];
125 };
126 
127 }  // namespace tf_tracking
128 
129 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
130