1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ 18 19 #include "tensorflow/tools/android/test/jni/object_tracking/keypoint.h" 20 21 namespace tf_tracking { 22 23 // A class that records keypoint correspondences from pairs of 24 // consecutive frames. 25 class FramePair { 26 public: FramePair()27 FramePair() 28 : start_time_(0), 29 end_time_(0), 30 number_of_keypoints_(0) {} 31 32 // Cleans up the FramePair so that they can be reused. 33 void Init(const int64_t start_time, const int64_t end_time); 34 35 void AdjustBox(const BoundingBox box, 36 float* const translation_x, 37 float* const translation_y, 38 float* const scale_x, 39 float* const scale_y) const; 40 41 private: 42 // Returns the weighted median of the given deltas, computed independently on 43 // x and y. Returns 0,0 in case of failure. The assumption is that a 44 // translation of 0.0 in the degenerate case is the best that can be done, and 45 // should not be considered an error. 46 // 47 // In the case of scale, a slight exception is made just to be safe and 48 // there is a check for 0.0 explicitly, but that shouldn't ever be possible to 49 // happen naturally because of the non-zero + parity checks in FillScales. 50 Point2f GetWeightedMedian(const float* const weights, 51 const Point2f* const deltas) const; 52 53 float GetWeightedMedianScale(const float* const weights, 54 const Point2f* const deltas) const; 55 56 // Weights points based on the query_point and cutoff_dist. 57 int FillWeights(const BoundingBox& box, 58 float* const weights) const; 59 60 // Fills in the array of deltas with the translations of the points 61 // between frames. 62 void FillTranslations(Point2f* const translations) const; 63 64 // Fills in the array of deltas with the relative scale factor of points 65 // relative to a given center. Has the ability to override the weight to 0 if 66 // a degenerate scale is detected. 67 // Translation is the amount the center of the box has moved from one frame to 68 // the next. 69 int FillScales(const Point2f& old_center, 70 const Point2f& translation, 71 float* const weights, 72 Point2f* const scales) const; 73 74 // TODO(andrewharp): Make these private. 75 public: 76 // The time at frame1. 77 int64_t start_time_; 78 79 // The time at frame2. 80 int64_t end_time_; 81 82 // This array will contain the keypoints found in frame 1. 83 Keypoint frame1_keypoints_[kMaxKeypoints]; 84 85 // Contain the locations of the keypoints from frame 1 in frame 2. 86 Keypoint frame2_keypoints_[kMaxKeypoints]; 87 88 // The number of keypoints in frame 1. 89 int number_of_keypoints_; 90 91 // Keeps track of which keypoint correspondences were actually found from one 92 // frame to another. 93 // The i-th element of this array will be non-zero if and only if the i-th 94 // keypoint of frame 1 was found in frame 2. 95 bool optical_flow_found_keypoint_[kMaxKeypoints]; 96 97 private: 98 TF_DISALLOW_COPY_AND_ASSIGN(FramePair); 99 }; 100 101 } // namespace tf_tracking 102 103 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ 104