/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_

#include <stdint.h>

#include <cmath>
#include <memory>

#include "tensorflow/tools/android/test/jni/object_tracking/config.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image_utils.h"
#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h"
#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h"
#include "tensorflow/tools/android/test/jni/object_tracking/utils.h"

namespace tf_tracking {

// Class that encapsulates all bulky processed data for a frame.
34 class ImageData { 35 public: ImageData(const int width,const int height)36 explicit ImageData(const int width, const int height) 37 : uv_frame_width_(width << 1), 38 uv_frame_height_(height << 1), 39 timestamp_(0), 40 image_(width, height) { 41 InitPyramid(width, height); 42 ResetComputationCache(); 43 } 44 45 private: ResetComputationCache()46 void ResetComputationCache() { 47 uv_data_computed_ = false; 48 integral_image_computed_ = false; 49 for (int i = 0; i < kNumPyramidLevels; ++i) { 50 spatial_x_computed_[i] = false; 51 spatial_y_computed_[i] = false; 52 pyramid_sqrt2_computed_[i * 2] = false; 53 pyramid_sqrt2_computed_[i * 2 + 1] = false; 54 } 55 } 56 InitPyramid(const int width,const int height)57 void InitPyramid(const int width, const int height) { 58 int level_width = width; 59 int level_height = height; 60 61 for (int i = 0; i < kNumPyramidLevels; ++i) { 62 pyramid_sqrt2_[i * 2] = NULL; 63 pyramid_sqrt2_[i * 2 + 1] = NULL; 64 spatial_x_[i] = NULL; 65 spatial_y_[i] = NULL; 66 67 level_width /= 2; 68 level_height /= 2; 69 } 70 71 // Alias the first pyramid level to image_. 72 pyramid_sqrt2_[0] = &image_; 73 } 74 75 public: ~ImageData()76 ~ImageData() { 77 // The first pyramid level is actually an alias to image_, 78 // so make sure it doesn't get deleted here. 
79 pyramid_sqrt2_[0] = NULL; 80 81 for (int i = 0; i < kNumPyramidLevels; ++i) { 82 SAFE_DELETE(pyramid_sqrt2_[i * 2]); 83 SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]); 84 SAFE_DELETE(spatial_x_[i]); 85 SAFE_DELETE(spatial_y_[i]); 86 } 87 } 88 SetData(const uint8_t * const new_frame,const int stride,const int64_t timestamp,const int downsample_factor)89 void SetData(const uint8_t* const new_frame, const int stride, 90 const int64_t timestamp, const int downsample_factor) { 91 SetData(new_frame, NULL, stride, timestamp, downsample_factor); 92 } 93 SetData(const uint8_t * const new_frame,const uint8_t * const uv_frame,const int stride,const int64_t timestamp,const int downsample_factor)94 void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame, 95 const int stride, const int64_t timestamp, 96 const int downsample_factor) { 97 ResetComputationCache(); 98 99 timestamp_ = timestamp; 100 101 TimeLog("SetData!"); 102 103 pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor); 104 pyramid_sqrt2_computed_[0] = true; 105 TimeLog("Downsampled image"); 106 107 if (uv_frame != NULL) { 108 if (u_data_.get() == NULL) { 109 u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_)); 110 v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_)); 111 } 112 113 GetUV(uv_frame, u_data_.get(), v_data_.get()); 114 uv_data_computed_ = true; 115 TimeLog("Copied UV data"); 116 } else { 117 LOGV("No uv data!"); 118 } 119 120 #ifdef LOG_TIME 121 // If profiling is enabled, precompute here to make it easier to distinguish 122 // total costs. 
123 Precompute(); 124 #endif 125 } 126 GetTimestamp()127 inline const uint64_t GetTimestamp() const { return timestamp_; } 128 GetImage()129 inline const Image<uint8_t>* GetImage() const { 130 SCHECK(pyramid_sqrt2_computed_[0], "image not set!"); 131 return pyramid_sqrt2_[0]; 132 } 133 GetPyramidSqrt2Level(const int level)134 const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const { 135 if (!pyramid_sqrt2_computed_[level]) { 136 SCHECK(level != 0, "Level equals 0!"); 137 if (level == 1) { 138 const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0); 139 if (pyramid_sqrt2_[level] == NULL) { 140 const int new_width = 141 (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2; 142 const int new_height = 143 (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 * 144 2; 145 146 pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height); 147 } 148 pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level); 149 } else { 150 const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2); 151 if (pyramid_sqrt2_[level] == NULL) { 152 pyramid_sqrt2_[level] = new Image<uint8_t>( 153 upper_level.GetWidth() / 2, upper_level.GetHeight() / 2); 154 } 155 pyramid_sqrt2_[level]->DownsampleAveraged( 156 upper_level.data(), upper_level.stride(), 2); 157 } 158 pyramid_sqrt2_computed_[level] = true; 159 } 160 return pyramid_sqrt2_[level]; 161 } 162 GetSpatialX(const int level)163 inline const Image<int32_t>* GetSpatialX(const int level) const { 164 if (!spatial_x_computed_[level]) { 165 const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2); 166 if (spatial_x_[level] == NULL) { 167 spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight()); 168 } 169 spatial_x_[level]->DerivativeX(src); 170 spatial_x_computed_[level] = true; 171 } 172 return spatial_x_[level]; 173 } 174 GetSpatialY(const int level)175 inline const Image<int32_t>* GetSpatialY(const int level) const { 176 if (!spatial_y_computed_[level]) { 177 const 
Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2); 178 if (spatial_y_[level] == NULL) { 179 spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight()); 180 } 181 spatial_y_[level]->DerivativeY(src); 182 spatial_y_computed_[level] = true; 183 } 184 return spatial_y_[level]; 185 } 186 187 // The integral image is currently only used for object detection, so lazily 188 // initialize it on request. GetIntegralImage()189 inline const IntegralImage* GetIntegralImage() const { 190 if (integral_image_.get() == NULL) { 191 integral_image_.reset(new IntegralImage(image_)); 192 } else if (!integral_image_computed_) { 193 integral_image_->Recompute(image_); 194 } 195 integral_image_computed_ = true; 196 return integral_image_.get(); 197 } 198 GetU()199 inline const Image<uint8_t>* GetU() const { 200 SCHECK(uv_data_computed_, "UV data not provided!"); 201 return u_data_.get(); 202 } 203 GetV()204 inline const Image<uint8_t>* GetV() const { 205 SCHECK(uv_data_computed_, "UV data not provided!"); 206 return v_data_.get(); 207 } 208 209 private: Precompute()210 void Precompute() { 211 // Create the smoothed pyramids. 212 for (int i = 0; i < kNumPyramidLevels * 2; i += 2) { 213 (void) GetPyramidSqrt2Level(i); 214 } 215 TimeLog("Created smoothed pyramids"); 216 217 // Create the smoothed pyramids. 218 for (int i = 1; i < kNumPyramidLevels * 2; i += 2) { 219 (void) GetPyramidSqrt2Level(i); 220 } 221 TimeLog("Created smoothed sqrt pyramids"); 222 223 // Create the spatial derivatives for frame 1. 
224 for (int i = 0; i < kNumPyramidLevels; ++i) { 225 (void) GetSpatialX(i); 226 (void) GetSpatialY(i); 227 } 228 TimeLog("Created spatial derivatives"); 229 230 (void) GetIntegralImage(); 231 TimeLog("Got integral image!"); 232 } 233 234 const int uv_frame_width_; 235 const int uv_frame_height_; 236 237 int64_t timestamp_; 238 239 Image<uint8_t> image_; 240 241 bool uv_data_computed_; 242 std::unique_ptr<Image<uint8_t> > u_data_; 243 std::unique_ptr<Image<uint8_t> > v_data_; 244 245 mutable bool spatial_x_computed_[kNumPyramidLevels]; 246 mutable Image<int32_t>* spatial_x_[kNumPyramidLevels]; 247 248 mutable bool spatial_y_computed_[kNumPyramidLevels]; 249 mutable Image<int32_t>* spatial_y_[kNumPyramidLevels]; 250 251 // Mutable so the lazy initialization can work when this class is const. 252 // Whether or not the integral image has been computed for the current image. 253 mutable bool integral_image_computed_; 254 mutable std::unique_ptr<IntegralImage> integral_image_; 255 256 mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2]; 257 mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2]; 258 259 TF_DISALLOW_COPY_AND_ASSIGN(ImageData); 260 }; 261 262 } // namespace tf_tracking 263 264 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ 265