• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
18 
19 #ifdef __RENDER_OPENGL__
20 #include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
21 #endif
22 #include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
23 
24 namespace tf_tracking {
25 
26 // A TrackedObject is a specific instance of an ObjectModel, with a known
27 // position in the world.
28 // It provides the last known position and number of recent detection failures,
29 // in addition to the more general appearance data associated with the object
30 // class (which is in ObjectModel).
31 // TODO(andrewharp): Make getters/setters follow styleguide.
32 class TrackedObject {
33  public:
34   TrackedObject(const std::string& id, const Image<uint8_t>& image,
35                 const BoundingBox& bounding_box, ObjectModelBase* const model);
36 
37   ~TrackedObject();
38 
39   void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp,
40                       const ImageData& image_data, const bool authoratative);
41 
42   // This method is called when the tracked object is detected at a
43   // given position, and allows the associated Model to grow and/or prune
44   // itself based on where the detection occurred.
45   void OnDetection(ObjectModelBase* const model,
46                    const BoundingBox& detection_position,
47                    const MatchScore match_score, const int64_t timestamp,
48                    const ImageData& image_data);
49 
50   // Called when there's no detection of the tracked object. This will cause
51   // a tracking failure after enough consecutive failures if the area under
52   // the current bounding box also doesn't meet a minimum correlation threshold
53   // with the model.
OnDetectionFailure()54   void OnDetectionFailure() {}
55 
IsVisible()56   inline bool IsVisible() const {
57     return tracked_correlation_ >= kMinimumCorrelationForTracking ||
58         num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
59   }
60 
GetCorrelation()61   inline float GetCorrelation() {
62     return tracked_correlation_;
63   }
64 
GetMatchScore()65   inline MatchScore GetMatchScore() {
66     return tracked_match_score_;
67   }
68 
GetPosition()69   inline BoundingBox GetPosition() const {
70     return last_known_position_;
71   }
72 
GetLastDetectionPosition()73   inline BoundingBox GetLastDetectionPosition() const {
74     return last_detection_position_;
75   }
76 
GetModel()77   inline const ObjectModelBase* GetModel() const {
78     return object_model_;
79   }
80 
GetName()81   inline const std::string& GetName() const {
82     return id_;
83   }
84 
Draw()85   inline void Draw() const {
86 #ifdef __RENDER_OPENGL__
87     if (tracked_correlation_ < kMinimumCorrelationForTracking) {
88       glColor4f(MAX(0.0f, -tracked_correlation_),
89                 MAX(0.0f, tracked_correlation_),
90                 0.0f,
91                 1.0f);
92     } else {
93       glColor4f(MAX(0.0f, -tracked_correlation_),
94                 MAX(0.0f, tracked_correlation_),
95                 1.0f,
96                 1.0f);
97     }
98 
99     // Render the box itself.
100     BoundingBox temp_box(last_known_position_);
101     DrawBox(temp_box);
102 
103     // Render a box inside this one (in case the actual box is hidden).
104     const float kBufferSize = 1.0f;
105     temp_box.left_ -= kBufferSize;
106     temp_box.top_ -= kBufferSize;
107     temp_box.right_ += kBufferSize;
108     temp_box.bottom_ += kBufferSize;
109     DrawBox(temp_box);
110 
111     // Render one outside as well.
112     temp_box.left_ -= -2.0f * kBufferSize;
113     temp_box.top_ -= -2.0f * kBufferSize;
114     temp_box.right_ += -2.0f * kBufferSize;
115     temp_box.bottom_ += -2.0f * kBufferSize;
116     DrawBox(temp_box);
117 #endif
118   }
119 
120   // Get current object's num_consecutive_frames_below_threshold_.
GetNumConsecutiveFramesBelowThreshold()121   inline int64_t GetNumConsecutiveFramesBelowThreshold() {
122     return num_consecutive_frames_below_threshold_;
123   }
124 
125   // Reset num_consecutive_frames_below_threshold_ to 0.
resetNumConsecutiveFramesBelowThreshold()126   inline void resetNumConsecutiveFramesBelowThreshold() {
127     num_consecutive_frames_below_threshold_ = 0;
128   }
129 
GetAllowableDistanceSquared()130   inline float GetAllowableDistanceSquared() const {
131     return allowable_detection_distance_;
132   }
133 
134  private:
135   // The unique id used throughout the system to identify this
136   // tracked object.
137   const std::string id_;
138 
139   // The last known position of the object.
140   BoundingBox last_known_position_;
141 
142   // The last known position of the object.
143   BoundingBox last_detection_position_;
144 
145   // When the position was last computed.
146   int64_t position_last_computed_time_;
147 
148   // The object model this tracked object is representative of.
149   ObjectModelBase* object_model_;
150 
151   Image<float> last_detection_thumbnail_;
152 
153   Image<float> last_frame_thumbnail_;
154 
155   // The correlation of the object model with the preview frame at its last
156   // tracked position.
157   float tracked_correlation_;
158 
159   MatchScore tracked_match_score_;
160 
161   // The number of consecutive frames that the tracked position for this object
162   // has been under the correlation threshold.
163   int num_consecutive_frames_below_threshold_;
164 
165   float allowable_detection_distance_;
166 
167   friend std::ostream& operator<<(std::ostream& stream,
168                                   const TrackedObject& tracked_object);
169 
170   TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
171 };
172 
173 inline std::ostream& operator<<(std::ostream& stream,
174                                 const TrackedObject& tracked_object) {
175   stream << tracked_object.id_
176       << " " << tracked_object.last_known_position_
177       << " " << tracked_object.position_last_computed_time_
178       << " " << tracked_object.num_consecutive_frames_below_threshold_
179       << " " << tracked_object.object_model_
180       << " " << tracked_object.tracked_correlation_;
181   return stream;
182 }
183 
184 }  // namespace tf_tracking
185 
186 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
187