1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/tools/android/test/jni/object_tracking/frame_pair.h"
17
18 #include <float.h>
19
20 #include "tensorflow/tools/android/test/jni/object_tracking/config.h"
21
22 namespace tf_tracking {
23
Init(const int64_t start_time,const int64_t end_time)24 void FramePair::Init(const int64_t start_time, const int64_t end_time) {
25 start_time_ = start_time;
26 end_time_ = end_time;
27 memset(optical_flow_found_keypoint_, false,
28 sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
29 number_of_keypoints_ = 0;
30 }
31
AdjustBox(const BoundingBox box,float * const translation_x,float * const translation_y,float * const scale_x,float * const scale_y) const32 void FramePair::AdjustBox(const BoundingBox box,
33 float* const translation_x,
34 float* const translation_y,
35 float* const scale_x,
36 float* const scale_y) const {
37 static float weights[kMaxKeypoints];
38 static Point2f deltas[kMaxKeypoints];
39 memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
40
41 BoundingBox resized_box(box);
42 resized_box.Scale(0.4f, 0.4f);
43 FillWeights(resized_box, weights);
44 FillTranslations(deltas);
45
46 const Point2f translation = GetWeightedMedian(weights, deltas);
47
48 *translation_x = translation.x;
49 *translation_y = translation.y;
50
51 const Point2f old_center = box.GetCenter();
52 const int good_scale_points =
53 FillScales(old_center, translation, weights, deltas);
54
55 // Default scale factor is 1 for x and y.
56 *scale_x = 1.0f;
57 *scale_y = 1.0f;
58
59 // The assumption is that all deltas that make it to this stage with a
60 // corresponding optical_flow_found_keypoint_[i] == true are not in
61 // themselves degenerate.
62 //
63 // The degeneracy with scale arose because if the points are too close to the
64 // center of the objects, the scale ratio determination might be incalculable.
65 //
66 // The check for kMinNumInRange is not a degeneracy check, but merely an
67 // attempt to ensure some sort of stability. The actual degeneracy check is in
68 // the comparison to EPSILON in FillScales (which I've updated to return the
69 // number good remaining as well).
70 static const int kMinNumInRange = 5;
71 if (good_scale_points >= kMinNumInRange) {
72 const float scale_factor = GetWeightedMedianScale(weights, deltas);
73
74 if (scale_factor > 0.0f) {
75 *scale_x = scale_factor;
76 *scale_y = scale_factor;
77 }
78 }
79 }
80
FillWeights(const BoundingBox & box,float * const weights) const81 int FramePair::FillWeights(const BoundingBox& box,
82 float* const weights) const {
83 // Compute the max score.
84 float max_score = -FLT_MAX;
85 float min_score = FLT_MAX;
86 for (int i = 0; i < kMaxKeypoints; ++i) {
87 if (optical_flow_found_keypoint_[i]) {
88 max_score = MAX(max_score, frame1_keypoints_[i].score_);
89 min_score = MIN(min_score, frame1_keypoints_[i].score_);
90 }
91 }
92
93 int num_in_range = 0;
94 for (int i = 0; i < kMaxKeypoints; ++i) {
95 if (!optical_flow_found_keypoint_[i]) {
96 weights[i] = 0.0f;
97 continue;
98 }
99
100 const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
101 if (in_box) {
102 ++num_in_range;
103 }
104
105 // The weighting based off distance. Anything within the bounding box
106 // has a weight of 1, and everything outside of that is within the range
107 // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
108 float distance_score = 1.0f;
109 if (!in_box) {
110 const Point2f initial = box.GetCenter();
111 const float sq_x_dist =
112 Square(initial.x - frame1_keypoints_[i].pos_.x);
113 const float sq_y_dist =
114 Square(initial.y - frame1_keypoints_[i].pos_.y);
115 const float squared_half_width = Square(box.GetWidth() / 2.0f);
116 const float squared_half_height = Square(box.GetHeight() / 2.0f);
117
118 static const float kOutOfBoxMultiplier = 0.5f;
119 distance_score = kOutOfBoxMultiplier *
120 MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
121 }
122
123 // The weighting based on relative score strength. kBaseScore - 1.0f.
124 float intrinsic_score = 1.0f;
125 if (max_score > min_score) {
126 static const float kBaseScore = 0.5f;
127 intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
128 (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
129 }
130
131 // The final score will be in the range [0, 1].
132 weights[i] = distance_score * intrinsic_score;
133 }
134
135 return num_in_range;
136 }
137
FillTranslations(Point2f * const translations) const138 void FramePair::FillTranslations(Point2f* const translations) const {
139 for (int i = 0; i < kMaxKeypoints; ++i) {
140 if (!optical_flow_found_keypoint_[i]) {
141 continue;
142 }
143 translations[i].x =
144 frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
145 translations[i].y =
146 frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
147 }
148 }
149
FillScales(const Point2f & old_center,const Point2f & translation,float * const weights,Point2f * const scales) const150 int FramePair::FillScales(const Point2f& old_center,
151 const Point2f& translation,
152 float* const weights,
153 Point2f* const scales) const {
154 int num_good = 0;
155 for (int i = 0; i < kMaxKeypoints; ++i) {
156 if (!optical_flow_found_keypoint_[i]) {
157 continue;
158 }
159
160 const Keypoint keypoint1 = frame1_keypoints_[i];
161 const Keypoint keypoint2 = frame2_keypoints_[i];
162
163 const float dist1_x = keypoint1.pos_.x - old_center.x;
164 const float dist1_y = keypoint1.pos_.y - old_center.y;
165
166 const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
167 const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
168
169 // Make sure that the scale makes sense; points too close to the center
170 // will result in either NaNs or infinite results for scale due to
171 // limited tracking and floating point resolution.
172 // Also check that the parity of the points is the same with respect to
173 // x and y, as we can't really make sense of data that has flipped.
174 if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
175 (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
176 ((dist2_y > EPSILON && dist1_y > EPSILON) ||
177 (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
178 scales[i].x = dist2_x / dist1_x;
179 scales[i].y = dist2_y / dist1_y;
180 ++num_good;
181 } else {
182 weights[i] = 0.0f;
183 scales[i].x = 1.0f;
184 scales[i].y = 1.0f;
185 }
186 }
187 return num_good;
188 }
189
// A single measurement (a displacement or a scale ratio) paired with the
// weight it contributes to a weighted median.
struct WeightedDelta {
  float weight;
  float delta;
};

// qsort() comparator for WeightedDelta. Orders entries by delta (not by
// weight), with the largest delta first.
//
// Equal deltas compare as 0. qsort() requires a consistent ordering, and a
// comparator that reports "a after b" and "b after a" for equal keys violates
// that contract; some qsort implementations misbehave with such comparators.
// Explicit comparisons also avoid relying on the sign of a float difference.
// NaN deltas compare equal to everything.
inline int WeightedDeltaCompare(const void* const a, const void* const b) {
  const float delta_a = static_cast<const WeightedDelta*>(a)->delta;
  const float delta_b = static_cast<const WeightedDelta*>(b)->delta;
  if (delta_a > delta_b) {
    return -1;  // Larger deltas sort first.
  }
  if (delta_a < delta_b) {
    return 1;
  }
  return 0;
}
200
201 // Returns the median delta from a sorted set of weighted deltas.
GetMedian(const int num_items,const WeightedDelta * const weighted_deltas,const float sum)202 static float GetMedian(const int num_items,
203 const WeightedDelta* const weighted_deltas,
204 const float sum) {
205 if (num_items == 0 || sum < EPSILON) {
206 return 0.0f;
207 }
208
209 float current_weight = 0.0f;
210 const float target_weight = sum / 2.0f;
211 for (int i = 0; i < num_items; ++i) {
212 if (weighted_deltas[i].weight > 0.0f) {
213 current_weight += weighted_deltas[i].weight;
214 if (current_weight >= target_weight) {
215 return weighted_deltas[i].delta;
216 }
217 }
218 }
219 LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
220 return 0.0f;
221 }
222
// Returns the per-axis weighted median of `deltas`. The x component is the
// weighted median of deltas[i].x and the y component that of deltas[i].y, both
// weighted by weights[i]. Entries whose weight is not positive are ignored. An
// axis whose total positive weight is below EPSILON yields 0 (see GetMedian).
//
// NOTE(review): the scratch buffer is a function-level static. This method is
// therefore not reentrant and not safe to call concurrently, despite being
// const.
Point2f FramePair::GetWeightedMedian(
    const float* const weights, const Point2f* const deltas) const {
  // TODO(andrewharp): only sort deltas that could possibly have an effect.
  static WeightedDelta weighted_deltas[kMaxKeypoints];

  Point2f median_delta;

  // Axis 0 is x, axis 1 is y. The two axes are handled independently.
  for (int axis = 0; axis < 2; ++axis) {
    const bool is_x = (axis == 0);

    // Pair each keypoint's component with its weight. Only positive weights
    // are summed, matching what GetMedian accumulates.
    float total_weight = 0.0f;
    for (int i = 0; i < kMaxKeypoints; ++i) {
      const float weight = weights[i];
      weighted_deltas[i].weight = weight;
      weighted_deltas[i].delta = is_x ? deltas[i].x : deltas[i].y;
      if (weight > 0.0f) {
        total_weight += weight;
      }
    }

    qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
          WeightedDeltaCompare);
    const float median =
        GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
    if (is_x) {
      median_delta.x = median;
    } else {
      median_delta.y = median;
    }
  }

  return median_delta;
}
268
GetWeightedMedianScale(const float * const weights,const Point2f * const deltas) const269 float FramePair::GetWeightedMedianScale(
270 const float* const weights, const Point2f* const deltas) const {
271 float median_delta;
272
273 // TODO(andrewharp): only sort deltas that could possibly have an effect.
274 static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
275
276 // Compute median scale value across x and y.
277 {
278 float total_weight = 0.0f;
279
280 // Add X values.
281 for (int i = 0; i < kMaxKeypoints; ++i) {
282 weighted_deltas[i].delta = deltas[i].x;
283 const float weight = weights[i];
284 weighted_deltas[i].weight = weight;
285 if (weight > 0.0f) {
286 total_weight += weight;
287 }
288 }
289
290 // Add Y values.
291 for (int i = 0; i < kMaxKeypoints; ++i) {
292 weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
293 const float weight = weights[i];
294 weighted_deltas[i + kMaxKeypoints].weight = weight;
295 if (weight > 0.0f) {
296 total_weight += weight;
297 }
298 }
299
300 qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
301 WeightedDeltaCompare);
302
303 median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
304 }
305
306 return median_delta;
307 }
308
309 } // namespace tf_tracking
310