• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
18 
19 #include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
21 
22 #include "tensorflow/examples/android/jni/object_tracking/config.h"
23 #include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
24 
25 namespace tf_tracking {
26 
27 // Class that helps OpticalFlow to speed up flow computation
28 // by caching coarse-grained flow.
29 class FlowCache {
30  public:
FlowCache(const OpticalFlowConfig * const config)31   explicit FlowCache(const OpticalFlowConfig* const config)
32       : config_(config),
33         image_size_(config->image_size),
34         optical_flow_(config),
35         fullframe_matrix_(NULL) {
36     for (int i = 0; i < kNumCacheLevels; ++i) {
37       const int curr_dims = BlockDimForCacheLevel(i);
38       has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
39       displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
40     }
41   }
42 
~FlowCache()43   ~FlowCache() {
44     for (int i = 0; i < kNumCacheLevels; ++i) {
45       SAFE_DELETE(has_cache_[i]);
46       SAFE_DELETE(displacements_[i]);
47     }
48     delete[](fullframe_matrix_);
49     fullframe_matrix_ = NULL;
50   }
51 
NextFrame(ImageData * const new_frame,const float * const align_matrix23)52   void NextFrame(ImageData* const new_frame,
53                  const float* const align_matrix23) {
54     ClearCache();
55     SetFullframeAlignmentMatrix(align_matrix23);
56     optical_flow_.NextFrame(new_frame);
57   }
58 
ClearCache()59   void ClearCache() {
60     for (int i = 0; i < kNumCacheLevels; ++i) {
61       has_cache_[i]->Clear(false);
62     }
63     delete[](fullframe_matrix_);
64     fullframe_matrix_ = NULL;
65   }
66 
67   // Finds the flow at a point, using the cache for performance.
FindFlowAtPoint(const float u_x,const float u_y,float * const flow_x,float * const flow_y)68   bool FindFlowAtPoint(const float u_x, const float u_y,
69                        float* const flow_x, float* const flow_y) const {
70     // Get the best guess from the cache.
71     const Point2f guess_from_cache = LookupGuess(u_x, u_y);
72 
73     *flow_x = guess_from_cache.x;
74     *flow_y = guess_from_cache.y;
75 
76     // Now refine the guess using the image pyramid.
77     for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
78         pyramid_level >= 0; --pyramid_level) {
79       if (!optical_flow_.FindFlowAtPointSingleLevel(
80           pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
81         return false;
82       }
83     }
84 
85     return true;
86   }
87 
88   // Determines the displacement of a point, and uses that to calculate a new
89   // position.
90   // Returns true iff the displacement determination worked and the new position
91   // is in the image.
FindNewPositionOfPoint(const float u_x,const float u_y,float * final_x,float * final_y)92   bool FindNewPositionOfPoint(const float u_x, const float u_y,
93                               float* final_x, float* final_y) const {
94     float flow_x;
95     float flow_y;
96     if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
97       return false;
98     }
99 
100     // Add in the displacement to get the final position.
101     *final_x = u_x + flow_x;
102     *final_y = u_y + flow_y;
103 
104     // Assign the best guess, if we're still in the image.
105     if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
106         InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
107       return true;
108     } else {
109       return false;
110     }
111   }
112 
113   // Comparison function for qsort.
Compare(const void * a,const void * b)114   static int Compare(const void* a, const void* b) {
115     return *reinterpret_cast<const float*>(a) -
116            *reinterpret_cast<const float*>(b);
117   }
118 
119   // Returns the median flow within the given bounding box as determined
120   // by a grid_width x grid_height grid.
GetMedianFlow(const BoundingBox & bounding_box,const bool filter_by_fb_error,const int grid_width,const int grid_height)121   Point2f GetMedianFlow(const BoundingBox& bounding_box,
122                         const bool filter_by_fb_error,
123                         const int grid_width,
124                         const int grid_height) const {
125     const int kMaxPoints = 100;
126     SCHECK(grid_width * grid_height <= kMaxPoints,
127           "Too many points for Median flow!");
128 
129     const BoundingBox valid_box = bounding_box.Intersect(
130         BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
131 
132     if (valid_box.GetArea() <= 0.0f) {
133       return Point2f(0, 0);
134     }
135 
136     float x_deltas[kMaxPoints];
137     float y_deltas[kMaxPoints];
138 
139     int curr_offset = 0;
140     for (int i = 0; i < grid_width; ++i) {
141       for (int j = 0; j < grid_height; ++j) {
142         const float x_in = valid_box.left_ +
143             (valid_box.GetWidth() * i) / (grid_width - 1);
144 
145         const float y_in = valid_box.top_ +
146             (valid_box.GetHeight() * j) / (grid_height - 1);
147 
148         float curr_flow_x;
149         float curr_flow_y;
150         const bool success = FindNewPositionOfPoint(x_in, y_in,
151                                                     &curr_flow_x, &curr_flow_y);
152 
153         if (success) {
154           x_deltas[curr_offset] = curr_flow_x;
155           y_deltas[curr_offset] = curr_flow_y;
156           ++curr_offset;
157         } else {
158           LOGW("Tracking failure!");
159         }
160       }
161     }
162 
163     if (curr_offset > 0) {
164       qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
165       qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
166 
167       return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
168     }
169 
170     LOGW("No points were valid!");
171     return Point2f(0, 0);
172   }
173 
SetFullframeAlignmentMatrix(const float * const align_matrix23)174   void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
175     if (align_matrix23 != NULL) {
176       if (fullframe_matrix_ == NULL) {
177         fullframe_matrix_ = new float[6];
178       }
179 
180       memcpy(fullframe_matrix_, align_matrix23,
181              6 * sizeof(fullframe_matrix_[0]));
182     }
183   }
184 
185  private:
LookupGuessFromLevel(const int cache_level,const float x,const float y)186   Point2f LookupGuessFromLevel(
187       const int cache_level, const float x, const float y) const {
188     // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
189 
190     // Cutoff at the target level and use the matrix transform instead.
191     if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
192       const float xnew = x * fullframe_matrix_[0] +
193                          y * fullframe_matrix_[1] +
194                              fullframe_matrix_[2];
195       const float ynew = x * fullframe_matrix_[3] +
196                          y * fullframe_matrix_[4] +
197                              fullframe_matrix_[5];
198 
199       return Point2f(xnew - x, ynew - y);
200     }
201 
202     const int level_dim = BlockDimForCacheLevel(cache_level);
203     const int pixels_per_cache_block_x =
204         (image_size_.width + level_dim - 1) / level_dim;
205     const int pixels_per_cache_block_y =
206         (image_size_.height + level_dim - 1) / level_dim;
207     const int index_x = x / pixels_per_cache_block_x;
208     const int index_y = y / pixels_per_cache_block_y;
209 
210     Point2f displacement;
211     if (!(*has_cache_[cache_level])[index_y][index_x]) {
212       (*has_cache_[cache_level])[index_y][index_x] = true;
213 
214       // Get the lower cache level's best guess, if it exists.
215       displacement = cache_level >= kNumCacheLevels - 1 ?
216           Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
217       // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
218       //      best_guess.x, best_guess.y);
219 
220       // Find the center of the block.
221       const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
222       const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
223       const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
224 
225       // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
226       //      "Querying %5.2f, %5.2f at pyramid level %d, ",
227       //      cache_level, index_x, index_y,
228       //      x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
229       //      center_x, center_y, pyramid_level);
230 
231       // TODO(andrewharp): Turn on FB error filtering.
232       const bool success = optical_flow_.FindFlowAtPointSingleLevel(
233           pyramid_level, center_x, center_y, false,
234           &displacement.x, &displacement.y);
235 
236       if (!success) {
237         LOGV("Computation of cached value failed for level %d!", cache_level);
238       }
239 
240       // Store the value for later use.
241       (*displacements_[cache_level])[index_y][index_x] = displacement;
242     } else {
243       displacement = (*displacements_[cache_level])[index_y][index_x];
244     }
245 
246     // LOGI("Returning %5.2f, %5.2f for level %d",
247     //      displacement.x, displacement.y, cache_level);
248     return displacement;
249   }
250 
LookupGuess(const float x,const float y)251   Point2f LookupGuess(const float x, const float y) const {
252     if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
253       return Point2f(0, 0);
254     }
255 
256     // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
257     if (kNumCacheLevels > 0) {
258       return LookupGuessFromLevel(0, x, y);
259     } else {
260       return Point2f(0, 0);
261     }
262   }
263 
264   // Returns the number of cache bins in each dimension for a given level
265   // of the cache.
BlockDimForCacheLevel(const int cache_level)266   int BlockDimForCacheLevel(const int cache_level) const {
267     // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
268     // thus if there are 4 cache levels, requesting level 3 (0-based) should
269     // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
270     // and so on.
271     int block_dim = kNumCacheLevels;
272     for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
273         --curr_level) {
274       block_dim *= kCacheBranchFactor;
275     }
276     return block_dim;
277   }
278 
279   // Returns the level of the image pyramid that a given cache level maps to.
PyramidLevelForCacheLevel(const int cache_level)280   int PyramidLevelForCacheLevel(const int cache_level) const {
281     // Higher cache and pyramid levels have smaller dimensions. The highest
282     // cache level should refer to the highest image pyramid level. The
283     // lower, finer image pyramid levels are uncached (assuming
284     // kNumCacheLevels < kNumPyramidLevels).
285     return cache_level + (kNumPyramidLevels - kNumCacheLevels);
286   }
287 
288   const OpticalFlowConfig* const config_;
289 
290   const Size image_size_;
291   OpticalFlow optical_flow_;
292 
293   float* fullframe_matrix_;
294 
295   // Whether this value is currently present in the cache.
296   Image<bool>* has_cache_[kNumCacheLevels];
297 
298   // The cached displacement values.
299   Image<Point2f>* displacements_[kNumCacheLevels];
300 
301   TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
302 };
303 
304 }  // namespace tf_tracking
305 
306 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
307