1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
18
19 #include "tensorflow/examples/android/jni/object_tracking/logging.h"
20 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
21
22 namespace tf_tracking {
23
// Integer width/height pair.
struct Size {
  // Stores the two dimensions exactly as given; no validation is performed.
  Size(const int width, const int height)
      : width(width),
        height(height) {}

  int width;   // Horizontal extent.
  int height;  // Vertical extent.
};
30
31
32 class Point2f {
33 public:
Point2f()34 Point2f() : x(0.0f), y(0.0f) {}
Point2f(const float x,const float y)35 Point2f(const float x, const float y) : x(x), y(y) {}
36
37 inline Point2f operator- (const Point2f& that) const {
38 return Point2f(this->x - that.x, this->y - that.y);
39 }
40
41 inline Point2f operator+ (const Point2f& that) const {
42 return Point2f(this->x + that.x, this->y + that.y);
43 }
44
45 inline Point2f& operator+= (const Point2f& that) {
46 this->x += that.x;
47 this->y += that.y;
48 return *this;
49 }
50
51 inline Point2f& operator-= (const Point2f& that) {
52 this->x -= that.x;
53 this->y -= that.y;
54 return *this;
55 }
56
57 inline Point2f operator- (const Point2f& that) {
58 return Point2f(this->x - that.x, this->y - that.y);
59 }
60
LengthSquared()61 inline float LengthSquared() {
62 return Square(this->x) + Square(this->y);
63 }
64
Length()65 inline float Length() {
66 return sqrtf(LengthSquared());
67 }
68
DistanceSquared(const Point2f & that)69 inline float DistanceSquared(const Point2f& that) {
70 return Square(this->x - that.x) + Square(this->y - that.y);
71 }
72
Distance(const Point2f & that)73 inline float Distance(const Point2f& that) {
74 return sqrtf(DistanceSquared(that));
75 }
76
77 float x;
78 float y;
79 };
80
81 inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
82 stream << point.x << "," << point.y;
83 return stream;
84 }
85
86 class BoundingBox {
87 public:
BoundingBox()88 BoundingBox()
89 : left_(0),
90 top_(0),
91 right_(0),
92 bottom_(0) {}
93
BoundingBox(const BoundingBox & bounding_box)94 BoundingBox(const BoundingBox& bounding_box)
95 : left_(bounding_box.left_),
96 top_(bounding_box.top_),
97 right_(bounding_box.right_),
98 bottom_(bounding_box.bottom_) {
99 SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
100 SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
101 }
102
BoundingBox(const float left,const float top,const float right,const float bottom)103 BoundingBox(const float left,
104 const float top,
105 const float right,
106 const float bottom)
107 : left_(left),
108 top_(top),
109 right_(right),
110 bottom_(bottom) {
111 SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
112 SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
113 }
114
BoundingBox(const Point2f & point1,const Point2f & point2)115 BoundingBox(const Point2f& point1, const Point2f& point2)
116 : left_(MIN(point1.x, point2.x)),
117 top_(MIN(point1.y, point2.y)),
118 right_(MAX(point1.x, point2.x)),
119 bottom_(MAX(point1.y, point2.y)) {}
120
CopyToArray(float * const bounds_array)121 inline void CopyToArray(float* const bounds_array) const {
122 bounds_array[0] = left_;
123 bounds_array[1] = top_;
124 bounds_array[2] = right_;
125 bounds_array[3] = bottom_;
126 }
127
GetWidth()128 inline float GetWidth() const {
129 return right_ - left_;
130 }
131
GetHeight()132 inline float GetHeight() const {
133 return bottom_ - top_;
134 }
135
GetArea()136 inline float GetArea() const {
137 const float width = GetWidth();
138 const float height = GetHeight();
139 if (width <= 0 || height <= 0) {
140 return 0.0f;
141 }
142
143 return width * height;
144 }
145
GetCenter()146 inline Point2f GetCenter() const {
147 return Point2f((left_ + right_) / 2.0f,
148 (top_ + bottom_) / 2.0f);
149 }
150
ValidBox()151 inline bool ValidBox() const {
152 return GetArea() > 0.0f;
153 }
154
155 // Returns a bounding box created from the overlapping area of these two.
Intersect(const BoundingBox & that)156 inline BoundingBox Intersect(const BoundingBox& that) const {
157 const float new_left = MAX(this->left_, that.left_);
158 const float new_right = MIN(this->right_, that.right_);
159
160 if (new_left >= new_right) {
161 return BoundingBox();
162 }
163
164 const float new_top = MAX(this->top_, that.top_);
165 const float new_bottom = MIN(this->bottom_, that.bottom_);
166
167 if (new_top >= new_bottom) {
168 return BoundingBox();
169 }
170
171 return BoundingBox(new_left, new_top, new_right, new_bottom);
172 }
173
174 // Returns a bounding box that can contain both boxes.
Union(const BoundingBox & that)175 inline BoundingBox Union(const BoundingBox& that) const {
176 return BoundingBox(MIN(this->left_, that.left_),
177 MIN(this->top_, that.top_),
178 MAX(this->right_, that.right_),
179 MAX(this->bottom_, that.bottom_));
180 }
181
PascalScore(const BoundingBox & that)182 inline float PascalScore(const BoundingBox& that) const {
183 SCHECK(GetArea() > 0.0f, "Empty bounding box!");
184 SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
185
186 const float intersect_area = this->Intersect(that).GetArea();
187
188 if (intersect_area <= 0) {
189 return 0;
190 }
191
192 const float score =
193 intersect_area / (GetArea() + that.GetArea() - intersect_area);
194 SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
195 return score;
196 }
197
Intersects(const BoundingBox & that)198 inline bool Intersects(const BoundingBox& that) const {
199 return InRange(that.left_, left_, right_)
200 || InRange(that.right_, left_, right_)
201 || InRange(that.top_, top_, bottom_)
202 || InRange(that.bottom_, top_, bottom_);
203 }
204
205 // Returns whether another bounding box is completely inside of this bounding
206 // box. Sharing edges is ok.
Contains(const BoundingBox & that)207 inline bool Contains(const BoundingBox& that) const {
208 return that.left_ >= left_ &&
209 that.right_ <= right_ &&
210 that.top_ >= top_ &&
211 that.bottom_ <= bottom_;
212 }
213
Contains(const Point2f & point)214 inline bool Contains(const Point2f& point) const {
215 return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
216 }
217
Shift(const Point2f shift_amount)218 inline void Shift(const Point2f shift_amount) {
219 left_ += shift_amount.x;
220 top_ += shift_amount.y;
221 right_ += shift_amount.x;
222 bottom_ += shift_amount.y;
223 }
224
ScaleOrigin(const float scale_x,const float scale_y)225 inline void ScaleOrigin(const float scale_x, const float scale_y) {
226 left_ *= scale_x;
227 right_ *= scale_x;
228 top_ *= scale_y;
229 bottom_ *= scale_y;
230 }
231
Scale(const float scale_x,const float scale_y)232 inline void Scale(const float scale_x, const float scale_y) {
233 const Point2f center = GetCenter();
234 const float half_width = GetWidth() / 2.0f;
235 const float half_height = GetHeight() / 2.0f;
236
237 left_ = center.x - half_width * scale_x;
238 right_ = center.x + half_width * scale_x;
239
240 top_ = center.y - half_height * scale_y;
241 bottom_ = center.y + half_height * scale_y;
242 }
243
244 float left_;
245 float top_;
246 float right_;
247 float bottom_;
248 };
249 inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
250 stream << "[" << box.left_ << " - " << box.right_
251 << ", " << box.top_ << " - " << box.bottom_
252 << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
253 return stream;
254 }
255
256
257 class BoundingSquare {
258 public:
BoundingSquare(const float x,const float y,const float size)259 BoundingSquare(const float x, const float y, const float size)
260 : x_(x), y_(y), size_(size) {}
261
BoundingSquare(const BoundingBox & box)262 explicit BoundingSquare(const BoundingBox& box)
263 : x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
264 #ifdef SANITY_CHECKS
265 if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
266 LOG(WARNING) << "This is not a square: " << box << std::endl;
267 }
268 #endif
269 }
270
ToBoundingBox()271 inline BoundingBox ToBoundingBox() const {
272 return BoundingBox(x_, y_, x_ + size_, y_ + size_);
273 }
274
ValidBox()275 inline bool ValidBox() {
276 return size_ > 0.0f;
277 }
278
Shift(const Point2f shift_amount)279 inline void Shift(const Point2f shift_amount) {
280 x_ += shift_amount.x;
281 y_ += shift_amount.y;
282 }
283
Scale(const float scale)284 inline void Scale(const float scale) {
285 const float new_size = size_ * scale;
286 const float position_diff = (new_size - size_) / 2.0f;
287 x_ -= position_diff;
288 y_ -= position_diff;
289 size_ = new_size;
290 }
291
292 float x_;
293 float y_;
294 float size_;
295 };
296 inline std::ostream& operator<<(std::ostream& stream,
297 const BoundingSquare& square) {
298 stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
299 return stream;
300 }
301
302
GetCenteredSquare(const BoundingBox & original_box,const float size)303 inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
304 const float size) {
305 const float width_diff = (original_box.GetWidth() - size) / 2.0f;
306 const float height_diff = (original_box.GetHeight() - size) / 2.0f;
307 return BoundingSquare(original_box.left_ + width_diff,
308 original_box.top_ + height_diff,
309 size);
310 }
311
GetCenteredSquare(const BoundingBox & original_box)312 inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
313 return GetCenteredSquare(
314 original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
315 }
316
317 } // namespace tf_tracking
318
319 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
320