/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_

#include <array>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/time/clock.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace task {
namespace vision {

// Base class providing common logic for vision models.
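//
// A concrete task is expected to subclass this API and implement
// `Postprocess` to convert raw output tensors into `OutputType` results. A
// minimal sketch, where `MyResult` and the exact inherited `Postprocess`
// signature are illustrative assumptions rather than guarantees of this
// header:
//
//   class MyTaskApi : public BaseVisionTaskApi<MyResult> {
//    public:
//     using BaseVisionTaskApi<MyResult>::BaseVisionTaskApi;
//
//    protected:
//     // Convert raw output tensors into a MyResult.
//     tflite::support::StatusOr<MyResult> Postprocess(
//         const std::vector<const TfLiteTensor*>& output_tensors,
//         const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
//   };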
template <class OutputType>
class BaseVisionTaskApi
    : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
                                             const BoundingBox&> {
 public:
  explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
      : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
                                        const BoundingBox&>(std::move(engine)) {
  }
  // BaseVisionTaskApi is neither copyable nor movable.
  BaseVisionTaskApi(const BaseVisionTaskApi&) = delete;
  BaseVisionTaskApi& operator=(const BaseVisionTaskApi&) = delete;

  // Number of bytes per pixel in the 8-bit RGB color space.
  static constexpr int kRgbPixelBytes = 3;

  // Sets the ProcessEngine used for image pre-processing. Must be called before
  // any inference is performed. Can be called between inferences to override
  // the current process engine.
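  //
  // For example, assuming the libyuv-backed engine (`kLibyuv`) offered by
  // `FrameBufferUtils`:
  //
  //   api->SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);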
  void SetProcessEngine(const FrameBufferUtils::ProcessEngine& process_engine) {
    frame_buffer_utils_ = FrameBufferUtils::Create(process_engine);
  }

 protected:
  using tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
                                        const BoundingBox&>::engine_;

  // Checks that the input tensor and metadata (if any) are valid, or returns
  // an error otherwise. This must be called once at initialization time,
  // before running inference, as it is a prerequisite for `Preprocess`.
  // Note: the underlying interpreter and metadata extractor are assumed to be
  // already successfully initialized before calling this method.
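  // The resulting `ImageTensorSpecs` gathers the expected input dimensions,
  // color space, tensor type and optional normalization parameters extracted
  // from the model and its metadata; they are consumed by `Preprocess` below.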
  virtual absl::Status CheckAndSetInputs() {
    ASSIGN_OR_RETURN(
        ImageTensorSpecs input_specs,
        BuildInputImageTensorSpecs(*engine_->interpreter(),
                                   *engine_->metadata_extractor()));

    if (input_specs.color_space != tflite::ColorSpaceType_RGB) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kUnimplemented,
          "BaseVisionTaskApi only supports RGB color space for now.");
    }

    input_specs_ = absl::make_unique<ImageTensorSpecs>(input_specs);

    return absl::OkStatus();
  }

  // Performs image preprocessing on the input frame buffer over the region of
  // interest so that it fits model requirements (e.g. upright 224x224 RGB) and
  // populates the corresponding input tensor. This is performed by (in this
  // order):
  // - cropping the frame buffer to the region of interest (which, in most
  //   cases, just covers the entire input image),
  // - resizing it (with bilinear interpolation, aspect-ratio *not* preserved)
  //   to the dimensions of the model input tensor,
  // - converting it to the colorspace of the input tensor (i.e. RGB, which is
  //   the only supported colorspace for now),
  // - rotating it according to its `Orientation` so that inference is performed
  //   on an "upright" image.
  //
  // IMPORTANT: as a consequence of cropping occurring first, the provided
  // region of interest is expressed in the unrotated frame of reference
  // coordinates system, i.e. in `[0, frame_buffer.width) x [0,
  // frame_buffer.height)`, which are the dimensions of the underlying
  // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
  // region of interest is not clamped, so this method will return a non-ok
  // status if the region is out of these bounds.
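  //
  // For example (illustrative values): given a 640x480 `frame_buffer` with
  // `Orientation::kRightTop` (i.e. a 90-degree clockwise rotation is needed
  // to display it upright), a region of interest covering the full image is
  // still expressed as origin (0, 0), width 640 and height 480, since
  // cropping happens before the rotation is applied.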
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          const FrameBuffer& frame_buffer,
                          const BoundingBox& roi) override {
    if (input_specs_ == nullptr) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal,
          "Uninitialized input tensor specs: CheckAndSetInputs must be called "
          "at initialization time.");
    }

    if (frame_buffer_utils_ == nullptr) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal,
          "Uninitialized frame buffer utils: SetProcessEngine must be called "
          "at initialization time.");
    }

    if (input_tensors.size() != 1) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal, "A single input tensor is expected.");
    }

    // Input data to be normalized (if needed) and used for inference. In most
    // cases, this is the result of image preprocessing. In case no image
    // preprocessing is needed (see below), this points to the input frame
    // buffer raw data.
    const uint8* input_data;
    size_t input_data_byte_size;

    // Optional buffers in case image preprocessing is needed.
    std::unique_ptr<FrameBuffer> preprocessed_frame_buffer;
    std::vector<uint8> preprocessed_data;

    if (IsImagePreprocessingNeeded(frame_buffer, roi)) {
      // Preprocess input image to fit model requirements.
      // For now RGB is the only color space supported, which is ensured by
      // `CheckAndSetInputs`.
      FrameBuffer::Dimension to_buffer_dimension = {input_specs_->image_width,
                                                    input_specs_->image_height};
      input_data_byte_size =
          GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB);
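      // For a kRGB buffer this amounts to width * height * kRgbPixelBytes,
      // since the destination plane created below has no row padding.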
      preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0);
      input_data = preprocessed_data.data();

      FrameBuffer::Plane preprocessed_plane = {
          /*buffer=*/preprocessed_data.data(),
          /*stride=*/{input_specs_->image_width * kRgbPixelBytes,
                      kRgbPixelBytes}};
      preprocessed_frame_buffer = FrameBuffer::Create(
          {preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB,
          FrameBuffer::Orientation::kTopLeft);

      RETURN_IF_ERROR(frame_buffer_utils_->Preprocess(
          frame_buffer, roi, preprocessed_frame_buffer.get()));
    } else {
      // Input frame buffer already targets model requirements: skip image
      // preprocessing. For RGB, the data is always stored in a single plane.
      input_data = frame_buffer.plane(0).buffer;
      input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes *
                             frame_buffer.dimension().height;
    }

    // Then normalize pixel data (if needed) and populate the input tensor.
    switch (input_specs_->tensor_type) {
      case kTfLiteUInt8:
        if (input_tensors[0]->bytes != input_data_byte_size) {
          return tflite::support::CreateStatusWithPayload(
              absl::StatusCode::kInternal,
              "Size mismatch or unsupported padding bytes between pixel data "
              "and input tensor.");
        }
        // No normalization required: directly populate data.
        tflite::task::core::PopulateTensor(
            input_data, input_data_byte_size / sizeof(uint8), input_tensors[0]);
        break;
      case kTfLiteFloat32: {
        if (input_tensors[0]->bytes / sizeof(float) !=
            input_data_byte_size / sizeof(uint8)) {
          return tflite::support::CreateStatusWithPayload(
              absl::StatusCode::kInternal,
              "Size mismatch or unsupported padding bytes between pixel data "
              "and input tensor.");
        }
        // Normalize and populate.
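        // Each float value is computed as (pixel - mean) / std; the inverse
        // standard deviations are precomputed so the per-pixel work below is
        // a single subtract and multiply.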
        float* normalized_input_data =
            tflite::task::core::AssertAndReturnTypedTensor<float>(
                input_tensors[0]);
        const tflite::task::vision::NormalizationOptions&
            normalization_options = input_specs_->normalization_options.value();
        if (normalization_options.num_values == 1) {
          float mean_value = normalization_options.mean_values[0];
          float inv_std_value = (1.0f / normalization_options.std_values[0]);
          for (int i = 0; i < input_data_byte_size / sizeof(uint8);
               i++, input_data++, normalized_input_data++) {
            *normalized_input_data =
                inv_std_value * (static_cast<float>(*input_data) - mean_value);
          }
        } else {
          std::array<float, 3> inv_std_values = {
              1.0f / normalization_options.std_values[0],
              1.0f / normalization_options.std_values[1],
              1.0f / normalization_options.std_values[2]};
          for (int i = 0; i < input_data_byte_size / sizeof(uint8);
               i++, input_data++, normalized_input_data++) {
            *normalized_input_data = inv_std_values[i % 3] *
                                     (static_cast<float>(*input_data) -
                                      normalization_options.mean_values[i % 3]);
          }
        }
        break;
      }
      case kTfLiteInt8:
        return tflite::support::CreateStatusWithPayload(
            absl::StatusCode::kUnimplemented,
            "kTfLiteInt8 input type is not implemented yet.");
      default:
        return tflite::support::CreateStatusWithPayload(
            absl::StatusCode::kInternal, "Unexpected input tensor type.");
    }

    return absl::OkStatus();
  }

  // Utils for input image preprocessing (resizing, colorspace conversion,
  // etc.).
  std::unique_ptr<FrameBufferUtils> frame_buffer_utils_;

  // Parameters related to the input tensor, which represents an image.
  std::unique_ptr<ImageTensorSpecs> input_specs_;

 private:
  // Returns true if image preprocessing is needed, and false if it can be
  // skipped.
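  // Preprocessing can only be skipped when the region of interest covers the
  // whole buffer and the buffer is already an upright (kTopLeft) RGB image
  // whose dimensions match the model input.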
  bool IsImagePreprocessingNeeded(const FrameBuffer& frame_buffer,
                                  const BoundingBox& roi) {
    // Is crop required?
    if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
        roi.width() != frame_buffer.dimension().width ||
        roi.height() != frame_buffer.dimension().height) {
      return true;
    }

    // Are image transformations required?
    if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft ||
        frame_buffer.format() != FrameBuffer::Format::kRGB ||
        frame_buffer.dimension().width != input_specs_->image_width ||
        frame_buffer.dimension().height != input_specs_->image_height) {
      return true;
    }

    return false;
  }
};

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_