/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_

#include <array>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/time/clock.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace task {
namespace vision {

// Base class providing common logic for vision models.
template <class OutputType>
class BaseVisionTaskApi
    : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
                                             const BoundingBox&> {
 public:
  explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
      : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
                                        const BoundingBox&>(
            std::move(engine)) {}

  // BaseVisionTaskApi is neither copyable nor movable.
  BaseVisionTaskApi(const BaseVisionTaskApi&) = delete;
  BaseVisionTaskApi& operator=(const BaseVisionTaskApi&) = delete;

  // Number of bytes required for 8-bit per pixel RGB color space.
  static constexpr int kRgbPixelBytes = 3;

  // Sets the ProcessEngine used for image pre-processing. Must be called
  // before any inference is performed. Can be called between inferences to
  // override the current process engine.
  void SetProcessEngine(const FrameBufferUtils::ProcessEngine& process_engine) {
    frame_buffer_utils_ = FrameBufferUtils::Create(process_engine);
  }
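
  // A minimal initialization sketch (illustrative only: `MyVisionTask` is a
  // hypothetical subclass, and `engine` is assumed to have already been built
  // and initialized from a model):
  //
  //   auto task = absl::make_unique<MyVisionTask>(std::move(engine));
  //   task->SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  //   // The subclass can now expose public inference methods built on top of
  //   // the protected `Infer(frame_buffer, roi)` inherited from BaseTaskApi.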

 protected:
  using tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
                                        const BoundingBox&>::engine_;

  // Checks that the input tensor and metadata (if any) are valid, and returns
  // an error otherwise. This must be called once at initialization time,
  // before running inference, as it is a prerequisite for `Preprocess`.
  // Note: the underlying interpreter and metadata extractor are assumed to
  // have already been successfully initialized before this method is called.
  virtual absl::Status CheckAndSetInputs() {
    ASSIGN_OR_RETURN(
        ImageTensorSpecs input_specs,
        BuildInputImageTensorSpecs(*engine_->interpreter(),
                                   *engine_->metadata_extractor()));

    if (input_specs.color_space != tflite::ColorSpaceType_RGB) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kUnimplemented,
          "BaseVisionTaskApi only supports RGB color space for now.");
    }

    input_specs_ = absl::make_unique<ImageTensorSpecs>(input_specs);

    return absl::OkStatus();
  }

  // Performs image preprocessing on the input frame buffer over the region of
  // interest so that it fits model requirements (e.g. upright 224x224 RGB)
  // and populates the corresponding input tensor. This is performed by (in
  // this order):
  // - cropping the frame buffer to the region of interest (which, in most
  //   cases, just covers the entire input image),
  // - resizing it (with bilinear interpolation, aspect-ratio *not* preserved)
  //   to the dimensions of the model input tensor,
  // - converting it to the colorspace of the input tensor (i.e. RGB, which is
  //   the only supported colorspace for now),
  // - rotating it according to its `Orientation` so that inference is
  //   performed on an "upright" image.
  //
  // IMPORTANT: as a consequence of cropping occurring first, the provided
  // region of interest is expressed in the unrotated frame of reference
  // coordinate system, i.e. in `[0, frame_buffer.width) x
  // [0, frame_buffer.height)`, which are the dimensions of the underlying
  // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
  // region of interest is not clamped, so this method will return a non-ok
  // status if the region is out of these bounds.
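  //
  // For example, a region of interest covering an entire 640x480 frame buffer
  // (a sketch; with such a full-frame ROI, no actual cropping takes place)
  // could be built as:
  //
  //   BoundingBox roi;
  //   roi.set_origin_x(0);
  //   roi.set_origin_y(0);
  //   roi.set_width(640);
  //   roi.set_height(480);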
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          const FrameBuffer& frame_buffer,
                          const BoundingBox& roi) override {
    if (input_specs_ == nullptr) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal,
          "Uninitialized input tensor specs: CheckAndSetInputs must be called "
          "at initialization time.");
    }

    if (frame_buffer_utils_ == nullptr) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal,
          "Uninitialized frame buffer utils: SetProcessEngine must be called "
          "at initialization time.");
    }

    if (input_tensors.size() != 1) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal, "A single input tensor is expected.");
    }

    // Input data to be normalized (if needed) and used for inference. In most
    // cases, this is the result of image preprocessing. In case no image
    // preprocessing is needed (see below), this points to the input frame
    // buffer raw data.
    const uint8* input_data;
    size_t input_data_byte_size;

    // Optional buffers in case image preprocessing is needed.
    std::unique_ptr<FrameBuffer> preprocessed_frame_buffer;
    std::vector<uint8> preprocessed_data;

    if (IsImagePreprocessingNeeded(frame_buffer, roi)) {
      // Preprocess the input image to fit model requirements. For now RGB is
      // the only color space supported, which is ensured by
      // `CheckAndSetInputs`.
      FrameBuffer::Dimension to_buffer_dimension = {
          input_specs_->image_width, input_specs_->image_height};
      input_data_byte_size =
          GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB);
      preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0);
      input_data = preprocessed_data.data();

      FrameBuffer::Plane preprocessed_plane = {
          /*buffer=*/preprocessed_data.data(),
          /*stride=*/{input_specs_->image_width * kRgbPixelBytes,
                      kRgbPixelBytes}};
      preprocessed_frame_buffer = FrameBuffer::Create(
          {preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB,
          FrameBuffer::Orientation::kTopLeft);

      RETURN_IF_ERROR(frame_buffer_utils_->Preprocess(
          frame_buffer, roi, preprocessed_frame_buffer.get()));
    } else {
      // The input frame buffer already targets model requirements: skip image
      // preprocessing. For RGB, the data is always stored in a single plane.
      input_data = frame_buffer.plane(0).buffer;
      input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes *
                             frame_buffer.dimension().height;
    }

    // Then normalize pixel data (if needed) and populate the input tensor.
    switch (input_specs_->tensor_type) {
      case kTfLiteUInt8:
        if (input_tensors[0]->bytes != input_data_byte_size) {
          return tflite::support::CreateStatusWithPayload(
              absl::StatusCode::kInternal,
              "Size mismatch or unsupported padding bytes between pixel data "
              "and input tensor.");
        }
        // No normalization required: directly populate data.
        tflite::task::core::PopulateTensor(
            input_data, input_data_byte_size / sizeof(uint8),
            input_tensors[0]);
        break;
      case kTfLiteFloat32: {
        if (input_tensors[0]->bytes / sizeof(float) !=
            input_data_byte_size / sizeof(uint8)) {
          return tflite::support::CreateStatusWithPayload(
              absl::StatusCode::kInternal,
              "Size mismatch or unsupported padding bytes between pixel data "
              "and input tensor.");
        }
        // Normalize and populate.
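        // Normalization maps each pixel value `v` to `(v - mean) / std`,
        // using the single or per-channel mean/std values provided by the
        // model metadata. For instance, mean = 127.5 and std = 127.5 map the
        // [0, 255] byte range to [-1.0, 1.0].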
        float* normalized_input_data =
            tflite::task::core::AssertAndReturnTypedTensor<float>(
                input_tensors[0]);
        const tflite::task::vision::NormalizationOptions&
            normalization_options = input_specs_->normalization_options.value();
        if (normalization_options.num_values == 1) {
          // A single mean/std pair: apply the same normalization to all
          // channels.
          float mean_value = normalization_options.mean_values[0];
          float inv_std_value = (1.0f / normalization_options.std_values[0]);
          for (int i = 0; i < input_data_byte_size / sizeof(uint8);
               i++, input_data++, normalized_input_data++) {
            *normalized_input_data =
                inv_std_value * (static_cast<float>(*input_data) - mean_value);
          }
        } else {
          // Per-channel mean/std values for the interleaved R, G, B channels.
          std::array<float, 3> inv_std_values = {
              1.0f / normalization_options.std_values[0],
              1.0f / normalization_options.std_values[1],
              1.0f / normalization_options.std_values[2]};
          for (int i = 0; i < input_data_byte_size / sizeof(uint8);
               i++, input_data++, normalized_input_data++) {
            *normalized_input_data = inv_std_values[i % 3] *
                                     (static_cast<float>(*input_data) -
                                      normalization_options.mean_values[i % 3]);
          }
        }
        break;
      }
      case kTfLiteInt8:
        return tflite::support::CreateStatusWithPayload(
            absl::StatusCode::kUnimplemented,
            "kTfLiteInt8 input type is not implemented yet.");
      default:
        return tflite::support::CreateStatusWithPayload(
            absl::StatusCode::kInternal, "Unexpected input tensor type.");
    }

    return absl::OkStatus();
  }

  // Utils for input image preprocessing (resizing, colorspace conversion,
  // etc).
  std::unique_ptr<FrameBufferUtils> frame_buffer_utils_;

  // Parameters related to the input tensor which represents an image.
  std::unique_ptr<ImageTensorSpecs> input_specs_;

 private:
  // Returns true if image preprocessing is needed, false if it can be
  // skipped.
  bool IsImagePreprocessingNeeded(const FrameBuffer& frame_buffer,
                                  const BoundingBox& roi) {
    // Is cropping required, i.e. does the region of interest cover anything
    // other than the full frame?
    if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
        roi.width() != frame_buffer.dimension().width ||
        roi.height() != frame_buffer.dimension().height) {
      return true;
    }

    // Are image transformations (rotation, format conversion or resizing)
    // required?
    if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft ||
        frame_buffer.format() != FrameBuffer::Format::kRGB ||
        frame_buffer.dimension().width != input_specs_->image_width ||
        frame_buffer.dimension().height != input_specs_->image_height) {
      return true;
    }

    return false;
  }
};

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_