1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <float.h>
17
18 #include "tensorflow/examples/android/jni/object_tracking/config.h"
19 #include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
20
21 namespace tf_tracking {
22
Init(const int64_t start_time,const int64_t end_time)23 void FramePair::Init(const int64_t start_time, const int64_t end_time) {
24 start_time_ = start_time;
25 end_time_ = end_time;
26 memset(optical_flow_found_keypoint_, false,
27 sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
28 number_of_keypoints_ = 0;
29 }
30
AdjustBox(const BoundingBox box,float * const translation_x,float * const translation_y,float * const scale_x,float * const scale_y) const31 void FramePair::AdjustBox(const BoundingBox box,
32 float* const translation_x,
33 float* const translation_y,
34 float* const scale_x,
35 float* const scale_y) const {
36 static float weights[kMaxKeypoints];
37 static Point2f deltas[kMaxKeypoints];
38 memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
39
40 BoundingBox resized_box(box);
41 resized_box.Scale(0.4f, 0.4f);
42 FillWeights(resized_box, weights);
43 FillTranslations(deltas);
44
45 const Point2f translation = GetWeightedMedian(weights, deltas);
46
47 *translation_x = translation.x;
48 *translation_y = translation.y;
49
50 const Point2f old_center = box.GetCenter();
51 const int good_scale_points =
52 FillScales(old_center, translation, weights, deltas);
53
54 // Default scale factor is 1 for x and y.
55 *scale_x = 1.0f;
56 *scale_y = 1.0f;
57
58 // The assumption is that all deltas that make it to this stage with a
59 // correspondending optical_flow_found_keypoint_[i] == true are not in
60 // themselves degenerate.
61 //
62 // The degeneracy with scale arose because if the points are too close to the
63 // center of the objects, the scale ratio determination might be incalculable.
64 //
65 // The check for kMinNumInRange is not a degeneracy check, but merely an
66 // attempt to ensure some sort of stability. The actual degeneracy check is in
67 // the comparison to EPSILON in FillScales (which I've updated to return the
68 // number good remaining as well).
69 static const int kMinNumInRange = 5;
70 if (good_scale_points >= kMinNumInRange) {
71 const float scale_factor = GetWeightedMedianScale(weights, deltas);
72
73 if (scale_factor > 0.0f) {
74 *scale_x = scale_factor;
75 *scale_y = scale_factor;
76 }
77 }
78 }
79
FillWeights(const BoundingBox & box,float * const weights) const80 int FramePair::FillWeights(const BoundingBox& box,
81 float* const weights) const {
82 // Compute the max score.
83 float max_score = -FLT_MAX;
84 float min_score = FLT_MAX;
85 for (int i = 0; i < kMaxKeypoints; ++i) {
86 if (optical_flow_found_keypoint_[i]) {
87 max_score = MAX(max_score, frame1_keypoints_[i].score_);
88 min_score = MIN(min_score, frame1_keypoints_[i].score_);
89 }
90 }
91
92 int num_in_range = 0;
93 for (int i = 0; i < kMaxKeypoints; ++i) {
94 if (!optical_flow_found_keypoint_[i]) {
95 weights[i] = 0.0f;
96 continue;
97 }
98
99 const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
100 if (in_box) {
101 ++num_in_range;
102 }
103
104 // The weighting based off distance. Anything within the bounding box
105 // has a weight of 1, and everything outside of that is within the range
106 // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
107 float distance_score = 1.0f;
108 if (!in_box) {
109 const Point2f initial = box.GetCenter();
110 const float sq_x_dist =
111 Square(initial.x - frame1_keypoints_[i].pos_.x);
112 const float sq_y_dist =
113 Square(initial.y - frame1_keypoints_[i].pos_.y);
114 const float squared_half_width = Square(box.GetWidth() / 2.0f);
115 const float squared_half_height = Square(box.GetHeight() / 2.0f);
116
117 static const float kOutOfBoxMultiplier = 0.5f;
118 distance_score = kOutOfBoxMultiplier *
119 MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
120 }
121
122 // The weighting based on relative score strength. kBaseScore - 1.0f.
123 float intrinsic_score = 1.0f;
124 if (max_score > min_score) {
125 static const float kBaseScore = 0.5f;
126 intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
127 (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
128 }
129
130 // The final score will be in the range [0, 1].
131 weights[i] = distance_score * intrinsic_score;
132 }
133
134 return num_in_range;
135 }
136
FillTranslations(Point2f * const translations) const137 void FramePair::FillTranslations(Point2f* const translations) const {
138 for (int i = 0; i < kMaxKeypoints; ++i) {
139 if (!optical_flow_found_keypoint_[i]) {
140 continue;
141 }
142 translations[i].x =
143 frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
144 translations[i].y =
145 frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
146 }
147 }
148
FillScales(const Point2f & old_center,const Point2f & translation,float * const weights,Point2f * const scales) const149 int FramePair::FillScales(const Point2f& old_center,
150 const Point2f& translation,
151 float* const weights,
152 Point2f* const scales) const {
153 int num_good = 0;
154 for (int i = 0; i < kMaxKeypoints; ++i) {
155 if (!optical_flow_found_keypoint_[i]) {
156 continue;
157 }
158
159 const Keypoint keypoint1 = frame1_keypoints_[i];
160 const Keypoint keypoint2 = frame2_keypoints_[i];
161
162 const float dist1_x = keypoint1.pos_.x - old_center.x;
163 const float dist1_y = keypoint1.pos_.y - old_center.y;
164
165 const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
166 const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
167
168 // Make sure that the scale makes sense; points too close to the center
169 // will result in either NaNs or infinite results for scale due to
170 // limited tracking and floating point resolution.
171 // Also check that the parity of the points is the same with respect to
172 // x and y, as we can't really make sense of data that has flipped.
173 if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
174 (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
175 ((dist2_y > EPSILON && dist1_y > EPSILON) ||
176 (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
177 scales[i].x = dist2_x / dist1_x;
178 scales[i].y = dist2_y / dist1_y;
179 ++num_good;
180 } else {
181 weights[i] = 0.0f;
182 scales[i].x = 1.0f;
183 scales[i].y = 1.0f;
184 }
185 }
186 return num_good;
187 }
188
// A value (|delta|) paired with its weight, used when computing weighted
// medians.
struct WeightedDelta {
  float weight;
  float delta;
};

// qsort comparator that orders WeightedDelta entries by |delta| (not by
// weight), in descending order.
//
// It is a proper three-way comparison that returns 0 for equal deltas. The
// previous form, `(a - b) <= 0 ? 1 : -1`, reported both cmp(a, b) > 0 and
// cmp(b, a) > 0 for equal deltas. That is not a consistent total ordering,
// which qsort requires (undefined behaviour otherwise). Equal deltas are
// common here because unused slots share sentinel values.
inline int WeightedDeltaCompare(const void* const a, const void* const b) {
  const float delta_a = static_cast<const WeightedDelta*>(a)->delta;
  const float delta_b = static_cast<const WeightedDelta*>(b)->delta;
  if (delta_a > delta_b) {
    return -1;  // Larger deltas sort first.
  }
  if (delta_a < delta_b) {
    return 1;
  }
  return 0;
}
199
200 // Returns the median delta from a sorted set of weighted deltas.
GetMedian(const int num_items,const WeightedDelta * const weighted_deltas,const float sum)201 static float GetMedian(const int num_items,
202 const WeightedDelta* const weighted_deltas,
203 const float sum) {
204 if (num_items == 0 || sum < EPSILON) {
205 return 0.0f;
206 }
207
208 float current_weight = 0.0f;
209 const float target_weight = sum / 2.0f;
210 for (int i = 0; i < num_items; ++i) {
211 if (weighted_deltas[i].weight > 0.0f) {
212 current_weight += weighted_deltas[i].weight;
213 if (current_weight >= target_weight) {
214 return weighted_deltas[i].delta;
215 }
216 }
217 }
218 LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
219 return 0.0f;
220 }
221
// Returns the weighted median of |deltas|, computed independently for the x
// and y components. Each axis uses the same per-slot |weights|
// (kMaxKeypoints entries); only positive weights contribute.
//
// NOTE: the sort buffer is a function-local static reused on every call, so
// this method is not reentrant.
Point2f FramePair::GetWeightedMedian(
    const float* const weights, const Point2f* const deltas) const {
  // TODO(andrewharp): only sort deltas that could possibly have an effect.
  static WeightedDelta weighted_deltas[kMaxKeypoints];

  // Loads one component of |deltas| with its weights into the static buffer,
  // sorts it by value, and returns its weighted median.
  const auto median_along = [weights, deltas](const bool use_x) {
    float weight_sum = 0.0f;
    for (int k = 0; k < kMaxKeypoints; ++k) {
      const float w = weights[k];
      weighted_deltas[k].weight = w;
      weighted_deltas[k].delta = use_x ? deltas[k].x : deltas[k].y;
      if (w > 0.0f) {
        weight_sum += w;
      }
    }
    qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
          WeightedDeltaCompare);
    return GetMedian(kMaxKeypoints, weighted_deltas, weight_sum);
  };

  Point2f median_delta;
  median_delta.x = median_along(true);
  median_delta.y = median_along(false);
  return median_delta;
}
267
GetWeightedMedianScale(const float * const weights,const Point2f * const deltas) const268 float FramePair::GetWeightedMedianScale(
269 const float* const weights, const Point2f* const deltas) const {
270 float median_delta;
271
272 // TODO(andrewharp): only sort deltas that could possibly have an effect.
273 static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
274
275 // Compute median scale value across x and y.
276 {
277 float total_weight = 0.0f;
278
279 // Add X values.
280 for (int i = 0; i < kMaxKeypoints; ++i) {
281 weighted_deltas[i].delta = deltas[i].x;
282 const float weight = weights[i];
283 weighted_deltas[i].weight = weight;
284 if (weight > 0.0f) {
285 total_weight += weight;
286 }
287 }
288
289 // Add Y values.
290 for (int i = 0; i < kMaxKeypoints; ++i) {
291 weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
292 const float weight = weights[i];
293 weighted_deltas[i + kMaxKeypoints].weight = weight;
294 if (weight > 0.0f) {
295 total_weight += weight;
296 }
297 }
298
299 qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
300 WeightedDeltaCompare);
301
302 median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
303 }
304
305 return median_delta;
306 }
307
308 } // namespace tf_tracking
309