1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ 18 19 #ifdef __RENDER_OPENGL__ 20 #include "tensorflow/examples/android/jni/object_tracking/gl_utils.h" 21 #endif 22 #include "tensorflow/examples/android/jni/object_tracking/object_detector.h" 23 24 namespace tf_tracking { 25 26 // A TrackedObject is a specific instance of an ObjectModel, with a known 27 // position in the world. 28 // It provides the last known position and number of recent detection failures, 29 // in addition to the more general appearance data associated with the object 30 // class (which is in ObjectModel). 31 // TODO(andrewharp): Make getters/setters follow styleguide. 32 class TrackedObject { 33 public: 34 TrackedObject(const std::string& id, const Image<uint8_t>& image, 35 const BoundingBox& bounding_box, ObjectModelBase* const model); 36 37 ~TrackedObject(); 38 39 void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp, 40 const ImageData& image_data, const bool authoratative); 41 42 // This method is called when the tracked object is detected at a 43 // given position, and allows the associated Model to grow and/or prune 44 // itself based on where the detection occurred. 45 void OnDetection(ObjectModelBase* const model, 46 const BoundingBox& detection_position, 47 const MatchScore match_score, const int64_t timestamp, 48 const ImageData& image_data); 49 50 // Called when there's no detection of the tracked object. This will cause 51 // a tracking failure after enough consecutive failures if the area under 52 // the current bounding box also doesn't meet a minimum correlation threshold 53 // with the model. OnDetectionFailure()54 void OnDetectionFailure() {} 55 IsVisible()56 inline bool IsVisible() const { 57 return tracked_correlation_ >= kMinimumCorrelationForTracking || 58 num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures; 59 } 60 GetCorrelation()61 inline float GetCorrelation() { 62 return tracked_correlation_; 63 } 64 GetMatchScore()65 inline MatchScore GetMatchScore() { 66 return tracked_match_score_; 67 } 68 GetPosition()69 inline BoundingBox GetPosition() const { 70 return last_known_position_; 71 } 72 GetLastDetectionPosition()73 inline BoundingBox GetLastDetectionPosition() const { 74 return last_detection_position_; 75 } 76 GetModel()77 inline const ObjectModelBase* GetModel() const { 78 return object_model_; 79 } 80 GetName()81 inline const std::string& GetName() const { 82 return id_; 83 } 84 Draw()85 inline void Draw() const { 86 #ifdef __RENDER_OPENGL__ 87 if (tracked_correlation_ < kMinimumCorrelationForTracking) { 88 glColor4f(MAX(0.0f, -tracked_correlation_), 89 MAX(0.0f, tracked_correlation_), 90 0.0f, 91 1.0f); 92 } else { 93 glColor4f(MAX(0.0f, -tracked_correlation_), 94 MAX(0.0f, tracked_correlation_), 95 1.0f, 96 1.0f); 97 } 98 99 // Render the box itself. 100 BoundingBox temp_box(last_known_position_); 101 DrawBox(temp_box); 102 103 // Render a box inside this one (in case the actual box is hidden). 104 const float kBufferSize = 1.0f; 105 temp_box.left_ -= kBufferSize; 106 temp_box.top_ -= kBufferSize; 107 temp_box.right_ += kBufferSize; 108 temp_box.bottom_ += kBufferSize; 109 DrawBox(temp_box); 110 111 // Render one outside as well. 112 temp_box.left_ -= -2.0f * kBufferSize; 113 temp_box.top_ -= -2.0f * kBufferSize; 114 temp_box.right_ += -2.0f * kBufferSize; 115 temp_box.bottom_ += -2.0f * kBufferSize; 116 DrawBox(temp_box); 117 #endif 118 } 119 120 // Get current object's num_consecutive_frames_below_threshold_. GetNumConsecutiveFramesBelowThreshold()121 inline int64_t GetNumConsecutiveFramesBelowThreshold() { 122 return num_consecutive_frames_below_threshold_; 123 } 124 125 // Reset num_consecutive_frames_below_threshold_ to 0. resetNumConsecutiveFramesBelowThreshold()126 inline void resetNumConsecutiveFramesBelowThreshold() { 127 num_consecutive_frames_below_threshold_ = 0; 128 } 129 GetAllowableDistanceSquared()130 inline float GetAllowableDistanceSquared() const { 131 return allowable_detection_distance_; 132 } 133 134 private: 135 // The unique id used throughout the system to identify this 136 // tracked object. 137 const std::string id_; 138 139 // The last known position of the object. 140 BoundingBox last_known_position_; 141 142 // The last known position of the object. 143 BoundingBox last_detection_position_; 144 145 // When the position was last computed. 146 int64_t position_last_computed_time_; 147 148 // The object model this tracked object is representative of. 149 ObjectModelBase* object_model_; 150 151 Image<float> last_detection_thumbnail_; 152 153 Image<float> last_frame_thumbnail_; 154 155 // The correlation of the object model with the preview frame at its last 156 // tracked position. 157 float tracked_correlation_; 158 159 MatchScore tracked_match_score_; 160 161 // The number of consecutive frames that the tracked position for this object 162 // has been under the correlation threshold. 163 int num_consecutive_frames_below_threshold_; 164 165 float allowable_detection_distance_; 166 167 friend std::ostream& operator<<(std::ostream& stream, 168 const TrackedObject& tracked_object); 169 170 TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject); 171 }; 172 173 inline std::ostream& operator<<(std::ostream& stream, 174 const TrackedObject& tracked_object) { 175 stream << tracked_object.id_ 176 << " " << tracked_object.last_known_position_ 177 << " " << tracked_object.position_last_computed_time_ 178 << " " << tracked_object.num_consecutive_frames_below_threshold_ 179 << " " << tracked_object.object_model_ 180 << " " << tracked_object.tracked_correlation_; 181 return stream; 182 } 183 184 } // namespace tf_tracking 185 186 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ 187