• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/android/test/jni/object_tracking/frame_pair.h"
17 
18 #include <float.h>
19 
20 #include "tensorflow/tools/android/test/jni/object_tracking/config.h"
21 
22 namespace tf_tracking {
23 
Init(const int64_t start_time,const int64_t end_time)24 void FramePair::Init(const int64_t start_time, const int64_t end_time) {
25   start_time_ = start_time;
26   end_time_ = end_time;
27   memset(optical_flow_found_keypoint_, false,
28          sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
29   number_of_keypoints_ = 0;
30 }
31 
AdjustBox(const BoundingBox box,float * const translation_x,float * const translation_y,float * const scale_x,float * const scale_y) const32 void FramePair::AdjustBox(const BoundingBox box,
33                           float* const translation_x,
34                           float* const translation_y,
35                           float* const scale_x,
36                           float* const scale_y) const {
37   static float weights[kMaxKeypoints];
38   static Point2f deltas[kMaxKeypoints];
39   memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
40 
41   BoundingBox resized_box(box);
42   resized_box.Scale(0.4f, 0.4f);
43   FillWeights(resized_box, weights);
44   FillTranslations(deltas);
45 
46   const Point2f translation = GetWeightedMedian(weights, deltas);
47 
48   *translation_x = translation.x;
49   *translation_y = translation.y;
50 
51   const Point2f old_center = box.GetCenter();
52   const int good_scale_points =
53       FillScales(old_center, translation, weights, deltas);
54 
55   // Default scale factor is 1 for x and y.
56   *scale_x = 1.0f;
57   *scale_y = 1.0f;
58 
59   // The assumption is that all deltas that make it to this stage with a
60   // corresponding optical_flow_found_keypoint_[i] == true are not in
61   // themselves degenerate.
62   //
63   // The degeneracy with scale arose because if the points are too close to the
64   // center of the objects, the scale ratio determination might be incalculable.
65   //
66   // The check for kMinNumInRange is not a degeneracy check, but merely an
67   // attempt to ensure some sort of stability. The actual degeneracy check is in
68   // the comparison to EPSILON in FillScales (which I've updated to return the
69   // number good remaining as well).
70   static const int kMinNumInRange = 5;
71   if (good_scale_points >= kMinNumInRange) {
72     const float scale_factor = GetWeightedMedianScale(weights, deltas);
73 
74     if (scale_factor > 0.0f) {
75       *scale_x = scale_factor;
76       *scale_y = scale_factor;
77     }
78   }
79 }
80 
FillWeights(const BoundingBox & box,float * const weights) const81 int FramePair::FillWeights(const BoundingBox& box,
82                            float* const weights) const {
83   // Compute the max score.
84   float max_score = -FLT_MAX;
85   float min_score = FLT_MAX;
86   for (int i = 0; i < kMaxKeypoints; ++i) {
87     if (optical_flow_found_keypoint_[i]) {
88       max_score = MAX(max_score, frame1_keypoints_[i].score_);
89       min_score = MIN(min_score, frame1_keypoints_[i].score_);
90     }
91   }
92 
93   int num_in_range = 0;
94   for (int i = 0; i < kMaxKeypoints; ++i) {
95     if (!optical_flow_found_keypoint_[i]) {
96       weights[i] = 0.0f;
97       continue;
98     }
99 
100     const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
101     if (in_box) {
102       ++num_in_range;
103     }
104 
105     // The weighting based off distance.  Anything within the bounding box
106     // has a weight of 1, and everything outside of that is within the range
107     // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
108     float distance_score = 1.0f;
109     if (!in_box) {
110       const Point2f initial = box.GetCenter();
111       const float sq_x_dist =
112           Square(initial.x - frame1_keypoints_[i].pos_.x);
113       const float sq_y_dist =
114           Square(initial.y - frame1_keypoints_[i].pos_.y);
115       const float squared_half_width = Square(box.GetWidth() / 2.0f);
116       const float squared_half_height = Square(box.GetHeight() / 2.0f);
117 
118       static const float kOutOfBoxMultiplier = 0.5f;
119       distance_score = kOutOfBoxMultiplier *
120           MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
121     }
122 
123     // The weighting based on relative score strength. kBaseScore - 1.0f.
124     float intrinsic_score =  1.0f;
125     if (max_score > min_score) {
126       static const float kBaseScore = 0.5f;
127       intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
128          (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
129     }
130 
131     // The final score will be in the range [0, 1].
132     weights[i] = distance_score * intrinsic_score;
133   }
134 
135   return num_in_range;
136 }
137 
FillTranslations(Point2f * const translations) const138 void FramePair::FillTranslations(Point2f* const translations) const {
139   for (int i = 0; i < kMaxKeypoints; ++i) {
140     if (!optical_flow_found_keypoint_[i]) {
141       continue;
142     }
143     translations[i].x =
144         frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
145     translations[i].y =
146         frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
147   }
148 }
149 
FillScales(const Point2f & old_center,const Point2f & translation,float * const weights,Point2f * const scales) const150 int FramePair::FillScales(const Point2f& old_center,
151                           const Point2f& translation,
152                           float* const weights,
153                           Point2f* const scales) const {
154   int num_good = 0;
155   for (int i = 0; i < kMaxKeypoints; ++i) {
156     if (!optical_flow_found_keypoint_[i]) {
157       continue;
158     }
159 
160     const Keypoint keypoint1 = frame1_keypoints_[i];
161     const Keypoint keypoint2 = frame2_keypoints_[i];
162 
163     const float dist1_x = keypoint1.pos_.x - old_center.x;
164     const float dist1_y = keypoint1.pos_.y - old_center.y;
165 
166     const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
167     const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
168 
169     // Make sure that the scale makes sense; points too close to the center
170     // will result in either NaNs or infinite results for scale due to
171     // limited tracking and floating point resolution.
172     // Also check that the parity of the points is the same with respect to
173     // x and y, as we can't really make sense of data that has flipped.
174     if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
175          (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
176          ((dist2_y > EPSILON && dist1_y > EPSILON) ||
177           (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
178       scales[i].x = dist2_x / dist1_x;
179       scales[i].y = dist2_y / dist1_y;
180       ++num_good;
181     } else {
182       weights[i] = 0.0f;
183       scales[i].x = 1.0f;
184       scales[i].y = 1.0f;
185     }
186   }
187   return num_good;
188 }
189 
// A value paired with the weight it contributes to a weighted median.
struct WeightedDelta {
  float weight;
  float delta;
};

// qsort comparator: orders WeightedDelta entries by |delta| in DESCENDING
// order (largest first). |weight| is ignored.
//
// qsort requires a consistent total order. The previous implementation never
// returned 0 and answered "a after b" for both (a, b) and (b, a) on ties,
// which is undefined behavior for qsort. Here ties compare equal (0), and NaN
// deltas are ordered after all numbers so the ordering stays total.
inline int WeightedDeltaCompare(const void* const a, const void* const b) {
  const float delta_a = static_cast<const WeightedDelta*>(a)->delta;
  const float delta_b = static_cast<const WeightedDelta*>(b)->delta;
  if (delta_a > delta_b) {
    return -1;
  }
  if (delta_a < delta_b) {
    return 1;
  }
  // Equal, or at least one operand is NaN (NaN != NaN).
  const bool a_is_nan = delta_a != delta_a;
  const bool b_is_nan = delta_b != delta_b;
  return static_cast<int>(a_is_nan) - static_cast<int>(b_is_nan);
}
200 
201 // Returns the median delta from a sorted set of weighted deltas.
GetMedian(const int num_items,const WeightedDelta * const weighted_deltas,const float sum)202 static float GetMedian(const int num_items,
203                        const WeightedDelta* const weighted_deltas,
204                        const float sum) {
205   if (num_items == 0 || sum < EPSILON) {
206     return 0.0f;
207   }
208 
209   float current_weight = 0.0f;
210   const float target_weight = sum / 2.0f;
211   for (int i = 0; i < num_items; ++i) {
212     if (weighted_deltas[i].weight > 0.0f) {
213       current_weight += weighted_deltas[i].weight;
214       if (current_weight >= target_weight) {
215         return weighted_deltas[i].delta;
216       }
217     }
218   }
219   LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
220   return 0.0f;
221 }
222 
// Returns the per-axis weighted median of |deltas|, each entry weighted by the
// matching entry of |weights|. The x and y medians are computed
// independently. Non-positive weights do not contribute.
//
// NOTE(review): uses a function-local static scratch buffer, so this is not
// reentrant.
Point2f FramePair::GetWeightedMedian(
    const float* const weights, const Point2f* const deltas) const {
  // TODO(andrewharp): only sort deltas that could possibly have an effect.
  static WeightedDelta weighted_deltas[kMaxKeypoints];

  // Loads one axis of |deltas| into the scratch buffer, sorts it and takes the
  // weighted median.
  const auto axis_median = [weights, deltas](const bool use_x) {
    float total_weight = 0.0f;
    for (int i = 0; i < kMaxKeypoints; ++i) {
      const float weight = weights[i];
      weighted_deltas[i].weight = weight;
      weighted_deltas[i].delta = use_x ? deltas[i].x : deltas[i].y;
      if (weight > 0.0f) {
        total_weight += weight;
      }
    }
    qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
          WeightedDeltaCompare);
    return GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
  };

  Point2f median_delta;
  median_delta.x = axis_median(true);
  median_delta.y = axis_median(false);
  return median_delta;
}
268 
GetWeightedMedianScale(const float * const weights,const Point2f * const deltas) const269 float FramePair::GetWeightedMedianScale(
270     const float* const weights, const Point2f* const deltas) const {
271   float median_delta;
272 
273   // TODO(andrewharp): only sort deltas that could possibly have an effect.
274   static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
275 
276   // Compute median scale value across x and y.
277   {
278     float total_weight = 0.0f;
279 
280     // Add X values.
281     for (int i = 0; i < kMaxKeypoints; ++i) {
282       weighted_deltas[i].delta = deltas[i].x;
283       const float weight = weights[i];
284       weighted_deltas[i].weight = weight;
285       if (weight > 0.0f) {
286         total_weight += weight;
287       }
288     }
289 
290     // Add Y values.
291     for (int i = 0; i < kMaxKeypoints; ++i) {
292       weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
293       const float weight = weights[i];
294       weighted_deltas[i + kMaxKeypoints].weight = weight;
295       if (weight > 0.0f) {
296         total_weight += weight;
297       }
298     }
299 
300     qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
301           WeightedDeltaCompare);
302 
303     median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
304   }
305 
306   return median_delta;
307 }
308 
309 }  // namespace tf_tracking
310