• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/status/status.h"
23 #include "tensorflow/lite/core/api/op_resolver.h"
24 #include "tensorflow_lite_support/cc/port/statusor.h"
25 #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
26 #include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
27 #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
28 #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
29 #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
30 #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
31 #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
32 
33 namespace tflite {
34 namespace task {
35 namespace vision {
36 
37 // Performs segmentation on images.
38 //
39 // The API expects a TFLite model with optional, but strongly recommended,
40 // TFLite Model Metadata.
41 //
42 // Input tensor:
43 //   (kTfLiteUInt8/kTfLiteFloat32)
44 //    - image input of size `[batch x height x width x channels]`.
45 //    - batch inference is not supported (`batch` is required to be 1).
46 //    - only RGB inputs are supported (`channels` is required to be 3).
47 //    - if type is kTfLiteFloat32, NormalizationOptions are required to be
48 //      attached to the metadata for input normalization.
49 // Output tensor:
50 //   (kTfLiteUInt8/kTfLiteFloat32)
51 //    - tensor of size `[batch x mask_height x mask_width x num_classes]`, where
52 //      `batch` is required to be 1, `mask_width` and `mask_height` are the
53 //      dimensions of the segmentation masks produced by the model, and
54 //      `num_classes` is the number of classes supported by the model.
55 //    - optional (but recommended) label map(s) can be attached as
56 //      AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per
57 //      line. The first such AssociatedFile (if any) is used to fill the
58 //      `class_name` field of the results. The `display_name` field is filled
59 //      from the AssociatedFile (if any) whose locale matches the
60 //      `display_names_locale` field of the `ImageSegmenterOptions` used at
61 //      creation time ("en" by default, i.e. English). If none of these are
62 //      available, only the `index` field of the results will be filled.
63 //
64 // An example of such model can be found at:
65 // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1
66 //
67 // A CLI demo tool is available for easily trying out this API, and provides
68 // example usage. See:
69 // examples/task/vision/desktop/image_segmenter_demo.cc
70 class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
71  public:
72   using BaseVisionTaskApi::BaseVisionTaskApi;
73 
74   // Creates an ImageSegmenter from the provided options. A non-default
75   // OpResolver can be specified in order to support custom Ops or specify a
76   // subset of built-in Ops.
77   static tflite::support::StatusOr<std::unique_ptr<ImageSegmenter>>
78   CreateFromOptions(
79       const ImageSegmenterOptions& options,
80       std::unique_ptr<tflite::OpResolver> resolver =
81           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
82 
83   // Performs actual segmentation on the provided FrameBuffer.
84   //
85   // The FrameBuffer can be of any size and any of the supported formats, i.e.
86   // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
87   // inference in order to (and in this order):
88   // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
89   //   the dimensions of the model input tensor,
90   // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
91   //   only supported colorspace for now),
92   // - rotate it according to its `Orientation` so that inference is performed
93   //   on an "upright" image.
94   //
95   // IMPORTANT: the returned segmentation masks are not direcly suited for
96   // display, in particular:
97   // * they are relative to the unrotated input frame, i.e. *not* taking into
98   //   account the `Orientation` flag of the input FrameBuffer,
99   // * their dimensions are intrinsic to the model, i.e. *not* dependent on the
100   //   input FrameBuffer dimensions.
101   //
102   // Example of such post-processing, assuming:
103   // * an input FrameBuffer with width=640, height=480, orientation=kLeftBottom
104   //   (i.e. the image will be rotated 90° clockwise during preprocessing to
105   //   make it "upright"),
106   // * a model outputting masks of size 224x224.
107   // In order to be directly displayable on top of the input image assumed to
108   // be displayed *with* the `Orientation` flag taken into account according to
109   // the EXIF specification (http://jpegclub.org/exif_orientation.html), the
110   // masks need to be:
111   // * re-scaled to 640 x 480,
112   // * then rotated 90° clockwise.
113   tflite::support::StatusOr<SegmentationResult> Segment(
114       const FrameBuffer& frame_buffer);
115 
116  protected:
117   // Post-processing to transform the raw model outputs into segmentation
118   // results.
119   tflite::support::StatusOr<SegmentationResult> Postprocess(
120       const std::vector<const TfLiteTensor*>& output_tensors,
121       const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
122 
123   // Performs sanity checks on the provided ImageSegmenterOptions.
124   static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options);
125 
126   // Initializes the Segmenter from the provided ImageSegmenterOptions, whose
127   // ownership is transferred to this object.
128   absl::Status Init(std::unique_ptr<ImageSegmenterOptions> options);
129 
130   // Performs pre-initialization actions.
131   virtual absl::Status PreInit();
132 
133   // The options used for building this image segmenter.
134   std::unique_ptr<ImageSegmenterOptions> options_;
135 
136   // The label map, extracted from the TFLite Model Metadata.
137   std::vector<LabelMapItem> label_map_;
138 
139  private:
140   // Performs sanity checks on the model outputs and extracts their metadata.
141   absl::Status CheckAndSetOutputs();
142 
143   // Initializes the colored labels list from `label_map_` and stores it in
144   // `colored_labels_`.
145   absl::Status InitColoredLabels();
146 
147   // Returns the output confidence at coordinates {x, y, depth}, dequantizing
148   // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true).
149   float GetOutputConfidence(const TfLiteTensor& output_tensor, int x, int y,
150                             int depth);
151 
152   // Prebuilt list of ColoredLabel attached to each Segmentation result. The
153   // i-th item in this list corresponds to the i-th label map item.
154   std::vector<Segmentation::ColoredLabel> colored_labels_;
155 
156   // Whether the model features quantized inference type (QUANTIZED_UINT8). This
157   // is currently detected by checking if all output tensors data type is uint8.
158   bool has_uint8_outputs_;
159 
160   // Expected output width.
161   int output_width_;
162   // Expected output height.
163   int output_height_;
164   // Expected output depth. This corresponds to the number of supported classes.
165   int output_depth_;
166 };
167 
168 }  // namespace vision
169 }  // namespace task
170 }  // namespace tflite
171 
172 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
173