• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
18 
19 #include "tensorflow/examples/android/jni/object_tracking/logging.h"
20 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
21 
22 namespace tf_tracking {
23 
// A plain integer width/height pair. There is no default constructor: both
// extents must be supplied explicitly.
struct Size {
  Size(const int width, const int height) : width{width}, height{height} {}

  int width;   // Horizontal extent.
  int height;  // Vertical extent.
};
30 
31 
// A 2D point (also used as a 2D vector) with float coordinates.
class Point2f {
 public:
  // Constructs the origin (0, 0).
  Point2f() : x(0.0f), y(0.0f) {}
  Point2f(const float x, const float y) : x(x), y(y) {}

  // Component-wise difference. A single const overload serves both const and
  // non-const operands.
  inline Point2f operator-(const Point2f& that) const {
    return Point2f(this->x - that.x, this->y - that.y);
  }

  // Component-wise sum.
  inline Point2f operator+(const Point2f& that) const {
    return Point2f(this->x + that.x, this->y + that.y);
  }

  inline Point2f& operator+=(const Point2f& that) {
    this->x += that.x;
    this->y += that.y;
    return *this;
  }

  inline Point2f& operator-=(const Point2f& that) {
    this->x -= that.x;
    this->y -= that.y;
    return *this;
  }

  // Squared Euclidean length of this point viewed as a vector from the origin.
  // Cheaper than Length() when only comparisons are needed.
  inline float LengthSquared() const {
    return x * x + y * y;
  }

  // Euclidean length of this point viewed as a vector from the origin.
  inline float Length() const {
    return sqrtf(LengthSquared());
  }

  // Squared Euclidean distance between this point and `that`.
  inline float DistanceSquared(const Point2f& that) const {
    const float dx = x - that.x;
    const float dy = y - that.y;
    return dx * dx + dy * dy;
  }

  // Euclidean distance between this point and `that`.
  inline float Distance(const Point2f& that) const {
    return sqrtf(DistanceSquared(that));
  }

  float x;
  float y;
};
80 
81 inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
82   stream << point.x << "," << point.y;
83   return stream;
84 }
85 
86 class BoundingBox {
87  public:
BoundingBox()88   BoundingBox()
89       : left_(0),
90         top_(0),
91         right_(0),
92         bottom_(0) {}
93 
BoundingBox(const BoundingBox & bounding_box)94   BoundingBox(const BoundingBox& bounding_box)
95       : left_(bounding_box.left_),
96         top_(bounding_box.top_),
97         right_(bounding_box.right_),
98         bottom_(bounding_box.bottom_) {
99     SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
100     SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
101   }
102 
BoundingBox(const float left,const float top,const float right,const float bottom)103   BoundingBox(const float left,
104               const float top,
105               const float right,
106               const float bottom)
107       : left_(left),
108         top_(top),
109         right_(right),
110         bottom_(bottom) {
111     SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
112     SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
113   }
114 
BoundingBox(const Point2f & point1,const Point2f & point2)115   BoundingBox(const Point2f& point1, const Point2f& point2)
116       : left_(MIN(point1.x, point2.x)),
117         top_(MIN(point1.y, point2.y)),
118         right_(MAX(point1.x, point2.x)),
119         bottom_(MAX(point1.y, point2.y)) {}
120 
CopyToArray(float * const bounds_array)121   inline void CopyToArray(float* const bounds_array) const {
122     bounds_array[0] = left_;
123     bounds_array[1] = top_;
124     bounds_array[2] = right_;
125     bounds_array[3] = bottom_;
126   }
127 
GetWidth()128   inline float GetWidth() const {
129     return right_ - left_;
130   }
131 
GetHeight()132   inline float GetHeight() const {
133     return bottom_ - top_;
134   }
135 
GetArea()136   inline float GetArea() const {
137     const float width = GetWidth();
138     const float height = GetHeight();
139     if (width <= 0 || height <= 0) {
140       return 0.0f;
141     }
142 
143     return width * height;
144   }
145 
GetCenter()146   inline Point2f GetCenter() const {
147     return Point2f((left_ + right_) / 2.0f,
148                    (top_ + bottom_) / 2.0f);
149   }
150 
ValidBox()151   inline bool ValidBox() const {
152     return GetArea() > 0.0f;
153   }
154 
155   // Returns a bounding box created from the overlapping area of these two.
Intersect(const BoundingBox & that)156   inline BoundingBox Intersect(const BoundingBox& that) const {
157     const float new_left = MAX(this->left_, that.left_);
158     const float new_right = MIN(this->right_, that.right_);
159 
160     if (new_left >= new_right) {
161       return BoundingBox();
162     }
163 
164     const float new_top = MAX(this->top_, that.top_);
165     const float new_bottom = MIN(this->bottom_, that.bottom_);
166 
167     if (new_top >= new_bottom) {
168       return BoundingBox();
169     }
170 
171     return BoundingBox(new_left, new_top,  new_right, new_bottom);
172   }
173 
174   // Returns a bounding box that can contain both boxes.
Union(const BoundingBox & that)175   inline BoundingBox Union(const BoundingBox& that) const {
176     return BoundingBox(MIN(this->left_, that.left_),
177                        MIN(this->top_, that.top_),
178                        MAX(this->right_, that.right_),
179                        MAX(this->bottom_, that.bottom_));
180   }
181 
PascalScore(const BoundingBox & that)182   inline float PascalScore(const BoundingBox& that) const {
183     SCHECK(GetArea() > 0.0f, "Empty bounding box!");
184     SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
185 
186     const float intersect_area = this->Intersect(that).GetArea();
187 
188     if (intersect_area <= 0) {
189       return 0;
190     }
191 
192     const float score =
193         intersect_area / (GetArea() + that.GetArea() - intersect_area);
194     SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
195     return score;
196   }
197 
Intersects(const BoundingBox & that)198   inline bool Intersects(const BoundingBox& that) const {
199     return InRange(that.left_, left_, right_)
200         || InRange(that.right_, left_, right_)
201         || InRange(that.top_, top_, bottom_)
202         || InRange(that.bottom_, top_, bottom_);
203   }
204 
205   // Returns whether another bounding box is completely inside of this bounding
206   // box. Sharing edges is ok.
Contains(const BoundingBox & that)207   inline bool Contains(const BoundingBox& that) const {
208     return that.left_ >= left_ &&
209         that.right_ <= right_ &&
210         that.top_ >= top_ &&
211         that.bottom_ <= bottom_;
212   }
213 
Contains(const Point2f & point)214   inline bool Contains(const Point2f& point) const {
215     return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
216   }
217 
Shift(const Point2f shift_amount)218   inline void Shift(const Point2f shift_amount) {
219     left_ += shift_amount.x;
220     top_ += shift_amount.y;
221     right_ += shift_amount.x;
222     bottom_ += shift_amount.y;
223   }
224 
ScaleOrigin(const float scale_x,const float scale_y)225   inline void ScaleOrigin(const float scale_x, const float scale_y) {
226     left_ *= scale_x;
227     right_ *= scale_x;
228     top_ *= scale_y;
229     bottom_ *= scale_y;
230   }
231 
Scale(const float scale_x,const float scale_y)232   inline void Scale(const float scale_x, const float scale_y) {
233     const Point2f center = GetCenter();
234     const float half_width = GetWidth() / 2.0f;
235     const float half_height = GetHeight() / 2.0f;
236 
237     left_ = center.x - half_width * scale_x;
238     right_ = center.x + half_width * scale_x;
239 
240     top_ = center.y - half_height * scale_y;
241     bottom_ = center.y + half_height * scale_y;
242   }
243 
244   float left_;
245   float top_;
246   float right_;
247   float bottom_;
248 };
249 inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
250   stream << "[" << box.left_ << " - " << box.right_
251          << ", " << box.top_ << " - " << box.bottom_
252          << ",  w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
253   return stream;
254 }
255 
256 
257 class BoundingSquare {
258  public:
BoundingSquare(const float x,const float y,const float size)259   BoundingSquare(const float x, const float y, const float size)
260       : x_(x), y_(y), size_(size) {}
261 
BoundingSquare(const BoundingBox & box)262   explicit BoundingSquare(const BoundingBox& box)
263       : x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
264 #ifdef SANITY_CHECKS
265     if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
266       LOG(WARNING) << "This is not a square: " << box << std::endl;
267     }
268 #endif
269   }
270 
ToBoundingBox()271   inline BoundingBox ToBoundingBox() const {
272     return BoundingBox(x_, y_, x_ + size_, y_ + size_);
273   }
274 
ValidBox()275   inline bool ValidBox() {
276     return size_ > 0.0f;
277   }
278 
Shift(const Point2f shift_amount)279   inline void Shift(const Point2f shift_amount) {
280     x_ += shift_amount.x;
281     y_ += shift_amount.y;
282   }
283 
Scale(const float scale)284   inline void Scale(const float scale) {
285     const float new_size = size_ * scale;
286     const float position_diff = (new_size - size_) / 2.0f;
287     x_ -= position_diff;
288     y_ -= position_diff;
289     size_ = new_size;
290   }
291 
292   float x_;
293   float y_;
294   float size_;
295 };
296 inline std::ostream& operator<<(std::ostream& stream,
297                                 const BoundingSquare& square) {
298   stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
299   return stream;
300 }
301 
302 
GetCenteredSquare(const BoundingBox & original_box,const float size)303 inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
304                                         const float size) {
305   const float width_diff = (original_box.GetWidth() - size) / 2.0f;
306   const float height_diff = (original_box.GetHeight() - size) / 2.0f;
307   return BoundingSquare(original_box.left_ + width_diff,
308                         original_box.top_ + height_diff,
309                         size);
310 }
311 
GetCenteredSquare(const BoundingBox & original_box)312 inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
313   return GetCenteredSquare(
314       original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
315 }
316 
317 }  // namespace tf_tracking
318 
319 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
320