1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ 18 19 #include "tensorflow/examples/android/jni/object_tracking/geom.h" 20 #include "tensorflow/examples/android/jni/object_tracking/utils.h" 21 22 #include "tensorflow/examples/android/jni/object_tracking/config.h" 23 #include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" 24 25 namespace tf_tracking { 26 27 // Class that helps OpticalFlow to speed up flow computation 28 // by caching coarse-grained flow. 29 class FlowCache { 30 public: FlowCache(const OpticalFlowConfig * const config)31 explicit FlowCache(const OpticalFlowConfig* const config) 32 : config_(config), 33 image_size_(config->image_size), 34 optical_flow_(config), 35 fullframe_matrix_(NULL) { 36 for (int i = 0; i < kNumCacheLevels; ++i) { 37 const int curr_dims = BlockDimForCacheLevel(i); 38 has_cache_[i] = new Image<bool>(curr_dims, curr_dims); 39 displacements_[i] = new Image<Point2f>(curr_dims, curr_dims); 40 } 41 } 42 ~FlowCache()43 ~FlowCache() { 44 for (int i = 0; i < kNumCacheLevels; ++i) { 45 SAFE_DELETE(has_cache_[i]); 46 SAFE_DELETE(displacements_[i]); 47 } 48 delete[](fullframe_matrix_); 49 fullframe_matrix_ = NULL; 50 } 51 NextFrame(ImageData * const new_frame,const float * const align_matrix23)52 void NextFrame(ImageData* const new_frame, 53 const float* const align_matrix23) { 54 ClearCache(); 55 SetFullframeAlignmentMatrix(align_matrix23); 56 optical_flow_.NextFrame(new_frame); 57 } 58 ClearCache()59 void ClearCache() { 60 for (int i = 0; i < kNumCacheLevels; ++i) { 61 has_cache_[i]->Clear(false); 62 } 63 delete[](fullframe_matrix_); 64 fullframe_matrix_ = NULL; 65 } 66 67 // Finds the flow at a point, using the cache for performance. FindFlowAtPoint(const float u_x,const float u_y,float * const flow_x,float * const flow_y)68 bool FindFlowAtPoint(const float u_x, const float u_y, 69 float* const flow_x, float* const flow_y) const { 70 // Get the best guess from the cache. 71 const Point2f guess_from_cache = LookupGuess(u_x, u_y); 72 73 *flow_x = guess_from_cache.x; 74 *flow_y = guess_from_cache.y; 75 76 // Now refine the guess using the image pyramid. 77 for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1; 78 pyramid_level >= 0; --pyramid_level) { 79 if (!optical_flow_.FindFlowAtPointSingleLevel( 80 pyramid_level, u_x, u_y, false, flow_x, flow_y)) { 81 return false; 82 } 83 } 84 85 return true; 86 } 87 88 // Determines the displacement of a point, and uses that to calculate a new 89 // position. 90 // Returns true iff the displacement determination worked and the new position 91 // is in the image. FindNewPositionOfPoint(const float u_x,const float u_y,float * final_x,float * final_y)92 bool FindNewPositionOfPoint(const float u_x, const float u_y, 93 float* final_x, float* final_y) const { 94 float flow_x; 95 float flow_y; 96 if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) { 97 return false; 98 } 99 100 // Add in the displacement to get the final position. 101 *final_x = u_x + flow_x; 102 *final_y = u_y + flow_y; 103 104 // Assign the best guess, if we're still in the image. 105 if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) && 106 InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) { 107 return true; 108 } else { 109 return false; 110 } 111 } 112 113 // Comparison function for qsort. Compare(const void * a,const void * b)114 static int Compare(const void* a, const void* b) { 115 return *reinterpret_cast<const float*>(a) - 116 *reinterpret_cast<const float*>(b); 117 } 118 119 // Returns the median flow within the given bounding box as determined 120 // by a grid_width x grid_height grid. GetMedianFlow(const BoundingBox & bounding_box,const bool filter_by_fb_error,const int grid_width,const int grid_height)121 Point2f GetMedianFlow(const BoundingBox& bounding_box, 122 const bool filter_by_fb_error, 123 const int grid_width, 124 const int grid_height) const { 125 const int kMaxPoints = 100; 126 SCHECK(grid_width * grid_height <= kMaxPoints, 127 "Too many points for Median flow!"); 128 129 const BoundingBox valid_box = bounding_box.Intersect( 130 BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1)); 131 132 if (valid_box.GetArea() <= 0.0f) { 133 return Point2f(0, 0); 134 } 135 136 float x_deltas[kMaxPoints]; 137 float y_deltas[kMaxPoints]; 138 139 int curr_offset = 0; 140 for (int i = 0; i < grid_width; ++i) { 141 for (int j = 0; j < grid_height; ++j) { 142 const float x_in = valid_box.left_ + 143 (valid_box.GetWidth() * i) / (grid_width - 1); 144 145 const float y_in = valid_box.top_ + 146 (valid_box.GetHeight() * j) / (grid_height - 1); 147 148 float curr_flow_x; 149 float curr_flow_y; 150 const bool success = FindNewPositionOfPoint(x_in, y_in, 151 &curr_flow_x, &curr_flow_y); 152 153 if (success) { 154 x_deltas[curr_offset] = curr_flow_x; 155 y_deltas[curr_offset] = curr_flow_y; 156 ++curr_offset; 157 } else { 158 LOGW("Tracking failure!"); 159 } 160 } 161 } 162 163 if (curr_offset > 0) { 164 qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare); 165 qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare); 166 167 return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]); 168 } 169 170 LOGW("No points were valid!"); 171 return Point2f(0, 0); 172 } 173 SetFullframeAlignmentMatrix(const float * const align_matrix23)174 void SetFullframeAlignmentMatrix(const float* const align_matrix23) { 175 if (align_matrix23 != NULL) { 176 if (fullframe_matrix_ == NULL) { 177 fullframe_matrix_ = new float[6]; 178 } 179 180 memcpy(fullframe_matrix_, align_matrix23, 181 6 * sizeof(fullframe_matrix_[0])); 182 } 183 } 184 185 private: LookupGuessFromLevel(const int cache_level,const float x,const float y)186 Point2f LookupGuessFromLevel( 187 const int cache_level, const float x, const float y) const { 188 // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level); 189 190 // Cutoff at the target level and use the matrix transform instead. 191 if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) { 192 const float xnew = x * fullframe_matrix_[0] + 193 y * fullframe_matrix_[1] + 194 fullframe_matrix_[2]; 195 const float ynew = x * fullframe_matrix_[3] + 196 y * fullframe_matrix_[4] + 197 fullframe_matrix_[5]; 198 199 return Point2f(xnew - x, ynew - y); 200 } 201 202 const int level_dim = BlockDimForCacheLevel(cache_level); 203 const int pixels_per_cache_block_x = 204 (image_size_.width + level_dim - 1) / level_dim; 205 const int pixels_per_cache_block_y = 206 (image_size_.height + level_dim - 1) / level_dim; 207 const int index_x = x / pixels_per_cache_block_x; 208 const int index_y = y / pixels_per_cache_block_y; 209 210 Point2f displacement; 211 if (!(*has_cache_[cache_level])[index_y][index_x]) { 212 (*has_cache_[cache_level])[index_y][index_x] = true; 213 214 // Get the lower cache level's best guess, if it exists. 215 displacement = cache_level >= kNumCacheLevels - 1 ? 216 Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y); 217 // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level, 218 // best_guess.x, best_guess.y); 219 220 // Find the center of the block. 221 const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x; 222 const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y; 223 const int pyramid_level = PyramidLevelForCacheLevel(cache_level); 224 225 // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] " 226 // "Querying %5.2f, %5.2f at pyramid level %d, ", 227 // cache_level, index_x, index_y, 228 // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y, 229 // center_x, center_y, pyramid_level); 230 231 // TODO(andrewharp): Turn on FB error filtering. 232 const bool success = optical_flow_.FindFlowAtPointSingleLevel( 233 pyramid_level, center_x, center_y, false, 234 &displacement.x, &displacement.y); 235 236 if (!success) { 237 LOGV("Computation of cached value failed for level %d!", cache_level); 238 } 239 240 // Store the value for later use. 241 (*displacements_[cache_level])[index_y][index_x] = displacement; 242 } else { 243 displacement = (*displacements_[cache_level])[index_y][index_x]; 244 } 245 246 // LOGI("Returning %5.2f, %5.2f for level %d", 247 // displacement.x, displacement.y, cache_level); 248 return displacement; 249 } 250 LookupGuess(const float x,const float y)251 Point2f LookupGuess(const float x, const float y) const { 252 if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) { 253 return Point2f(0, 0); 254 } 255 256 // LOGI("Looking up guess at %5.2f %5.2f.", x, y); 257 if (kNumCacheLevels > 0) { 258 return LookupGuessFromLevel(0, x, y); 259 } else { 260 return Point2f(0, 0); 261 } 262 } 263 264 // Returns the number of cache bins in each dimension for a given level 265 // of the cache. BlockDimForCacheLevel(const int cache_level)266 int BlockDimForCacheLevel(const int cache_level) const { 267 // The highest (coarsest) cache level has a block dim of kCacheBranchFactor, 268 // thus if there are 4 cache levels, requesting level 3 (0-based) should 269 // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2, 270 // and so on. 271 int block_dim = kNumCacheLevels; 272 for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level; 273 --curr_level) { 274 block_dim *= kCacheBranchFactor; 275 } 276 return block_dim; 277 } 278 279 // Returns the level of the image pyramid that a given cache level maps to. PyramidLevelForCacheLevel(const int cache_level)280 int PyramidLevelForCacheLevel(const int cache_level) const { 281 // Higher cache and pyramid levels have smaller dimensions. The highest 282 // cache level should refer to the highest image pyramid level. The 283 // lower, finer image pyramid levels are uncached (assuming 284 // kNumCacheLevels < kNumPyramidLevels). 285 return cache_level + (kNumPyramidLevels - kNumCacheLevels); 286 } 287 288 const OpticalFlowConfig* const config_; 289 290 const Size image_size_; 291 OpticalFlow optical_flow_; 292 293 float* fullframe_matrix_; 294 295 // Whether this value is currently present in the cache. 296 Image<bool>* has_cache_[kNumCacheLevels]; 297 298 // The cached displacement values. 299 Image<Point2f>* displacements_[kNumCacheLevels]; 300 301 TF_DISALLOW_COPY_AND_ASSIGN(FlowCache); 302 }; 303 304 } // namespace tf_tracking 305 306 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ 307