/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_

#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"

#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"

namespace tf_tracking {

typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;

inline std::ostream& operator<<(std::ostream& stream,
                                const TrackedObjectMap& map) {
  for (TrackedObjectMap::const_iterator iter = map.begin();
       iter != map.end(); ++iter) {
    const TrackedObject& tracked_object = *iter->second;
    const std::string& key = iter->first;
    stream << key << ": " << tracked_object;
  }
  return stream;
}


// ObjectTracker is the highest-level class in the tracking/detection framework.
// It handles basic image processing, keypoint detection, keypoint tracking,
// object tracking, and object detection/relocalization.
class ObjectTracker {
 public:
  ObjectTracker(const TrackerConfig* const config,
                ObjectDetectorBase* const detector);
  virtual ~ObjectTracker();

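  // Convenience overload for callers without a chroma (UV) plane; forwards to
  // the full NextFrame() below with a NULL uv_frame.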
  virtual void NextFrame(const uint8_t* const new_frame,
                         const int64_t timestamp,
                         const float* const alignment_matrix_2x3) {
    NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
  }

  // Called upon the arrival of a new frame of raw data.
  // Does all image processing, keypoint detection, and object
  // tracking/detection for registered objects.
  // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
  // represents the main transformation that has happened between the last
  // and the current frame.
  // Argument align_level is the pyramid level (where 0 == finest) that
  // the matrix is valid for.
  virtual void NextFrame(const uint8_t* const new_frame,
                         const uint8_t* const uv_frame, const int64_t timestamp,
                         const float* const alignment_matrix_2x3);

  virtual void RegisterNewObjectWithAppearance(const std::string& id,
                                               const uint8_t* const new_frame,
                                               const BoundingBox& bounding_box);

  // Updates the position of a tracked object, given that it was known to be at
  // a certain position at some point in the past.
  virtual void SetPreviousPositionOfObject(const std::string& id,
                                           const BoundingBox& bounding_box,
                                           const int64_t timestamp);

  // Sets the current position of the object in the most recent frame provided.
  virtual void SetCurrentPositionOfObject(const std::string& id,
                                          const BoundingBox& bounding_box);

  // Tells the ObjectTracker to stop tracking a target.
  void ForgetTarget(const std::string& id);

  // Fills the given out_data buffer with the latest detected keypoint
  // correspondences, first scaled by scale_factor (to adjust for downsampling
  // that may have occurred elsewhere), then packed in a fixed-point format.
  int GetKeypointsPacked(uint16_t* const out_data,
                         const float scale_factor) const;

  // Copy the keypoint arrays after computeFlow is called.
  // out_data should be at least kMaxKeypoints * kKeypointStep long.
  // Currently, its format is [x1 y1 found x2 y2 score] repeated N times,
  // where N is the number of keypoints tracked. N is returned as the result.
  int GetKeypoints(const bool only_found, float* const out_data) const;

  // Returns the current position of a box, given that it was at a certain
  // position at the given time.
  BoundingBox TrackBox(const BoundingBox& region,
                       const int64_t timestamp) const;

  // Returns the number of frames that have been passed to NextFrame().
  inline int GetNumFrames() const {
    return num_frames_;
  }

  // Returns true if an object with the given id is currently being tracked.
  inline bool HaveObject(const std::string& id) const {
    return objects_.find(id) != objects_.end();
  }

  // Returns the TrackedObject associated with the given id.
  inline const TrackedObject* GetObject(const std::string& id) const {
    TrackedObjectMap::const_iterator iter = objects_.find(id);
    CHECK_ALWAYS(iter != objects_.end(),
                 "Unknown object key! \"%s\"", id.c_str());
    TrackedObject* const object = iter->second;
    return object;
  }

  // Returns the TrackedObject associated with the given id.
  inline TrackedObject* GetObject(const std::string& id) {
    TrackedObjectMap::iterator iter = objects_.find(id);
    CHECK_ALWAYS(iter != objects_.end(),
                 "Unknown object key! \"%s\"", id.c_str());
    TrackedObject* const object = iter->second;
    return object;
  }

  // Returns true if the object with the given id is currently visible.
  bool IsObjectVisible(const std::string& id) const {
    SCHECK(HaveObject(id), "Don't have this object.");

    const TrackedObject* object = GetObject(id);
    return object->IsVisible();
  }

  // Renders the tracker's current state onto a canvas of the given dimensions;
  // frame_to_canvas maps frame coordinates to canvas coordinates.
  virtual void Draw(const int canvas_width, const int canvas_height,
                    const float* const frame_to_canvas) const;

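  // Example usage (illustrative sketch only; assumes a TrackerConfig, an
  // ObjectDetectorBase implementation, per-frame Y/UV image planes, and a
  // caller-supplied 2x3 frame-alignment matrix are available elsewhere):
  //
  //   ObjectTracker tracker(config, detector);
  //   tracker.NextFrame(y_plane, uv_plane, timestamp_ms, alignment_2x3);
  //   tracker.RegisterNewObjectWithAppearance("target", y_plane, initial_box);
  //   // ...then, on each subsequent frame:
  //   tracker.NextFrame(y_plane, uv_plane, timestamp_ms, alignment_2x3);
  //   if (tracker.HaveObject("target") && tracker.IsObjectVisible("target")) {
  //     const TrackedObject* target = tracker.GetObject("target");
  //     // Query target as needed.
  //   }
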
 protected:
  // Creates a new tracked object at the given position.
  // If an object model is provided, then that model will be associated with the
  // object. If not, a new model may be created from the appearance at the
  // initial position and registered with the object detector.
  virtual TrackedObject* MaybeAddObject(const std::string& id,
                                        const Image<uint8_t>& image,
                                        const BoundingBox& bounding_box,
                                        const ObjectModelBase* object_model);

  // Find the keypoints in the frame before the current frame.
  // If only one frame exists, keypoints will be found in that frame.
  void ComputeKeypoints(const bool cached_ok = false);

  // Finds the correspondences for all the points in the current pair of frames.
  // Stores the results in the given FramePair.
  void FindCorrespondences(FramePair* const curr_change) const;

  // Returns the index into frame_pairs_ of the frame pair that is offset
  // entries before the most recent one (offset 0 == most recent).
  inline int GetNthIndexFromEnd(const int offset) const {
    return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
  }

  BoundingBox TrackBox(const BoundingBox& region,
                       const FramePair& frame_pair) const;

  inline void IncrementFrameIndex() {
    // Move the current framechange index up.
    ++num_frames_;
    ++curr_num_frame_pairs_;

    // If we've got too many, push up the start of the queue.
    if (curr_num_frame_pairs_ > kNumFrames) {
      first_frame_index_ = GetNthIndexFromStart(1);
      --curr_num_frame_pairs_;
    }
  }

  inline int GetNthIndexFromStart(const int offset) const {
    SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
           "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
    return (first_frame_index_ + offset) % kNumFrames;
  }

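  // Worked example of the ring-buffer indexing above (kNumFrames value chosen
  // purely for illustration): if kNumFrames were 4 and IncrementFrameIndex()
  // had run six times, then first_frame_index_ == 2 and
  // curr_num_frame_pairs_ == 4, so GetNthIndexFromEnd(0) ==
  // GetNthIndexFromStart(3) == (2 + 3) % 4 == 1.
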
  void TrackObjects();

  const std::unique_ptr<const TrackerConfig> config_;

  const int frame_width_;
  const int frame_height_;

  int64_t curr_time_;

  int num_frames_;

  TrackedObjectMap objects_;

  FlowCache flow_cache_;

  KeypointDetector keypoint_detector_;

  int curr_num_frame_pairs_;
  int first_frame_index_;

  std::unique_ptr<ImageData> frame1_;
  std::unique_ptr<ImageData> frame2_;

  FramePair frame_pairs_[kNumFrames];

  std::unique_ptr<ObjectDetectorBase> detector_;

  int num_detected_;

 private:
  void TrackTarget(TrackedObject* const object);

  bool GetBestObjectForDetection(
      const Detection& detection, TrackedObject** match) const;

  void ProcessDetections(std::vector<Detection>* const detections);

  void DetectTargets();

  // Temp object used in ObjectTracker::CreateNewExample.
  mutable std::vector<BoundingSquare> squares;

  friend std::ostream& operator<<(std::ostream& stream,
                                  const ObjectTracker& tracker);

  TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
};

inline std::ostream& operator<<(std::ostream& stream,
                                const ObjectTracker& tracker) {
  stream << "Frame size: " << tracker.frame_width_ << "x"
         << tracker.frame_height_ << std::endl;

  stream << "Num frames: " << tracker.num_frames_ << std::endl;

  stream << "Curr time: " << tracker.curr_time_ << std::endl;

  const int first_frame_index = tracker.GetNthIndexFromStart(0);
  const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];

  const int last_frame_index = tracker.GetNthIndexFromEnd(0);
  const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];

  stream << "first frame: " << first_frame_index << ","
         << first_frame_pair.end_time_ << " "
         << "last frame: " << last_frame_index << ","
         << last_frame_pair.end_time_ << " diff: "
         << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
         << std::endl;

  stream << "Tracked targets:";
  stream << tracker.objects_;

  return stream;
}

}  // namespace tf_tracking

#endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_