• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <float.h>
17 
18 #include "tensorflow/examples/android/jni/object_tracking/config.h"
19 #include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
20 
21 namespace tf_tracking {
22 
Init(const int64_t start_time,const int64_t end_time)23 void FramePair::Init(const int64_t start_time, const int64_t end_time) {
24   start_time_ = start_time;
25   end_time_ = end_time;
26   memset(optical_flow_found_keypoint_, false,
27          sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
28   number_of_keypoints_ = 0;
29 }
30 
AdjustBox(const BoundingBox box,float * const translation_x,float * const translation_y,float * const scale_x,float * const scale_y) const31 void FramePair::AdjustBox(const BoundingBox box,
32                           float* const translation_x,
33                           float* const translation_y,
34                           float* const scale_x,
35                           float* const scale_y) const {
36   static float weights[kMaxKeypoints];
37   static Point2f deltas[kMaxKeypoints];
38   memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
39 
40   BoundingBox resized_box(box);
41   resized_box.Scale(0.4f, 0.4f);
42   FillWeights(resized_box, weights);
43   FillTranslations(deltas);
44 
45   const Point2f translation = GetWeightedMedian(weights, deltas);
46 
47   *translation_x = translation.x;
48   *translation_y = translation.y;
49 
50   const Point2f old_center = box.GetCenter();
51   const int good_scale_points =
52       FillScales(old_center, translation, weights, deltas);
53 
54   // Default scale factor is 1 for x and y.
55   *scale_x = 1.0f;
56   *scale_y = 1.0f;
57 
58   // The assumption is that all deltas that make it to this stage with a
59   // correspondending optical_flow_found_keypoint_[i] == true are not in
60   // themselves degenerate.
61   //
62   // The degeneracy with scale arose because if the points are too close to the
63   // center of the objects, the scale ratio determination might be incalculable.
64   //
65   // The check for kMinNumInRange is not a degeneracy check, but merely an
66   // attempt to ensure some sort of stability. The actual degeneracy check is in
67   // the comparison to EPSILON in FillScales (which I've updated to return the
68   // number good remaining as well).
69   static const int kMinNumInRange = 5;
70   if (good_scale_points >= kMinNumInRange) {
71     const float scale_factor = GetWeightedMedianScale(weights, deltas);
72 
73     if (scale_factor > 0.0f) {
74       *scale_x = scale_factor;
75       *scale_y = scale_factor;
76     }
77   }
78 }
79 
FillWeights(const BoundingBox & box,float * const weights) const80 int FramePair::FillWeights(const BoundingBox& box,
81                            float* const weights) const {
82   // Compute the max score.
83   float max_score = -FLT_MAX;
84   float min_score = FLT_MAX;
85   for (int i = 0; i < kMaxKeypoints; ++i) {
86     if (optical_flow_found_keypoint_[i]) {
87       max_score = MAX(max_score, frame1_keypoints_[i].score_);
88       min_score = MIN(min_score, frame1_keypoints_[i].score_);
89     }
90   }
91 
92   int num_in_range = 0;
93   for (int i = 0; i < kMaxKeypoints; ++i) {
94     if (!optical_flow_found_keypoint_[i]) {
95       weights[i] = 0.0f;
96       continue;
97     }
98 
99     const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
100     if (in_box) {
101       ++num_in_range;
102     }
103 
104     // The weighting based off distance.  Anything within the bounding box
105     // has a weight of 1, and everything outside of that is within the range
106     // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
107     float distance_score = 1.0f;
108     if (!in_box) {
109       const Point2f initial = box.GetCenter();
110       const float sq_x_dist =
111           Square(initial.x - frame1_keypoints_[i].pos_.x);
112       const float sq_y_dist =
113           Square(initial.y - frame1_keypoints_[i].pos_.y);
114       const float squared_half_width = Square(box.GetWidth() / 2.0f);
115       const float squared_half_height = Square(box.GetHeight() / 2.0f);
116 
117       static const float kOutOfBoxMultiplier = 0.5f;
118       distance_score = kOutOfBoxMultiplier *
119           MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
120     }
121 
122     // The weighting based on relative score strength. kBaseScore - 1.0f.
123     float intrinsic_score =  1.0f;
124     if (max_score > min_score) {
125       static const float kBaseScore = 0.5f;
126       intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
127          (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
128     }
129 
130     // The final score will be in the range [0, 1].
131     weights[i] = distance_score * intrinsic_score;
132   }
133 
134   return num_in_range;
135 }
136 
FillTranslations(Point2f * const translations) const137 void FramePair::FillTranslations(Point2f* const translations) const {
138   for (int i = 0; i < kMaxKeypoints; ++i) {
139     if (!optical_flow_found_keypoint_[i]) {
140       continue;
141     }
142     translations[i].x =
143         frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
144     translations[i].y =
145         frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
146   }
147 }
148 
FillScales(const Point2f & old_center,const Point2f & translation,float * const weights,Point2f * const scales) const149 int FramePair::FillScales(const Point2f& old_center,
150                           const Point2f& translation,
151                           float* const weights,
152                           Point2f* const scales) const {
153   int num_good = 0;
154   for (int i = 0; i < kMaxKeypoints; ++i) {
155     if (!optical_flow_found_keypoint_[i]) {
156       continue;
157     }
158 
159     const Keypoint keypoint1 = frame1_keypoints_[i];
160     const Keypoint keypoint2 = frame2_keypoints_[i];
161 
162     const float dist1_x = keypoint1.pos_.x - old_center.x;
163     const float dist1_y = keypoint1.pos_.y - old_center.y;
164 
165     const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
166     const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
167 
168     // Make sure that the scale makes sense; points too close to the center
169     // will result in either NaNs or infinite results for scale due to
170     // limited tracking and floating point resolution.
171     // Also check that the parity of the points is the same with respect to
172     // x and y, as we can't really make sense of data that has flipped.
173     if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
174          (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
175          ((dist2_y > EPSILON && dist1_y > EPSILON) ||
176           (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
177       scales[i].x = dist2_x / dist1_x;
178       scales[i].y = dist2_y / dist1_y;
179       ++num_good;
180     } else {
181       weights[i] = 0.0f;
182       scales[i].x = 1.0f;
183       scales[i].y = 1.0f;
184     }
185   }
186   return num_good;
187 }
188 
// A value (|delta|) paired with the weight it contributes to a weighted median.
struct WeightedDelta {
  float weight;
  float delta;
};

// qsort() comparator ordering WeightedDeltas by |delta| in DESCENDING order;
// |weight| is ignored. Returns a negative value if |a| sorts before |b|, a
// positive value if after, and 0 when the deltas are equal.
//
// Equal deltas must compare as 0. qsort requires the comparator to define a
// consistent total ordering, and reporting "greater" for both compare(a, b)
// and compare(b, a) violates that (undefined behavior).
inline int WeightedDeltaCompare(const void* const a, const void* const b) {
  const float delta_a = static_cast<const WeightedDelta*>(a)->delta;
  const float delta_b = static_cast<const WeightedDelta*>(b)->delta;
  if (delta_a > delta_b) {
    return -1;
  }
  if (delta_a < delta_b) {
    return 1;
  }
  return 0;
}
199 
200 // Returns the median delta from a sorted set of weighted deltas.
GetMedian(const int num_items,const WeightedDelta * const weighted_deltas,const float sum)201 static float GetMedian(const int num_items,
202                        const WeightedDelta* const weighted_deltas,
203                        const float sum) {
204   if (num_items == 0 || sum < EPSILON) {
205     return 0.0f;
206   }
207 
208   float current_weight = 0.0f;
209   const float target_weight = sum / 2.0f;
210   for (int i = 0; i < num_items; ++i) {
211     if (weighted_deltas[i].weight > 0.0f) {
212       current_weight += weighted_deltas[i].weight;
213       if (current_weight >= target_weight) {
214         return weighted_deltas[i].delta;
215       }
216     }
217   }
218   LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
219   return 0.0f;
220 }
221 
GetWeightedMedian(const float * const weights,const Point2f * const deltas) const222 Point2f FramePair::GetWeightedMedian(
223     const float* const weights, const Point2f* const deltas) const {
224   Point2f median_delta;
225 
226   // TODO(andrewharp): only sort deltas that could possibly have an effect.
227   static WeightedDelta weighted_deltas[kMaxKeypoints];
228 
229   // Compute median X value.
230   {
231     float total_weight = 0.0f;
232 
233     // Compute weighted mean and deltas.
234     for (int i = 0; i < kMaxKeypoints; ++i) {
235       weighted_deltas[i].delta = deltas[i].x;
236       const float weight = weights[i];
237       weighted_deltas[i].weight = weight;
238       if (weight > 0.0f) {
239         total_weight += weight;
240       }
241     }
242     qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
243           WeightedDeltaCompare);
244     median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
245   }
246 
247   // Compute median Y value.
248   {
249     float total_weight = 0.0f;
250 
251     // Compute weighted mean and deltas.
252     for (int i = 0; i < kMaxKeypoints; ++i) {
253       const float weight = weights[i];
254       weighted_deltas[i].weight = weight;
255       weighted_deltas[i].delta = deltas[i].y;
256       if (weight > 0.0f) {
257         total_weight += weight;
258       }
259     }
260     qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
261           WeightedDeltaCompare);
262     median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
263   }
264 
265   return median_delta;
266 }
267 
GetWeightedMedianScale(const float * const weights,const Point2f * const deltas) const268 float FramePair::GetWeightedMedianScale(
269     const float* const weights, const Point2f* const deltas) const {
270   float median_delta;
271 
272   // TODO(andrewharp): only sort deltas that could possibly have an effect.
273   static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
274 
275   // Compute median scale value across x and y.
276   {
277     float total_weight = 0.0f;
278 
279     // Add X values.
280     for (int i = 0; i < kMaxKeypoints; ++i) {
281       weighted_deltas[i].delta = deltas[i].x;
282       const float weight = weights[i];
283       weighted_deltas[i].weight = weight;
284       if (weight > 0.0f) {
285         total_weight += weight;
286       }
287     }
288 
289     // Add Y values.
290     for (int i = 0; i < kMaxKeypoints; ++i) {
291       weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
292       const float weight = weights[i];
293       weighted_deltas[i + kMaxKeypoints].weight = weight;
294       if (weight > 0.0f) {
295         total_weight += weight;
296       }
297     }
298 
299     qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
300           WeightedDeltaCompare);
301 
302     median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
303   }
304 
305   return median_delta;
306 }
307 
308 }  // namespace tf_tracking
309