/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_

#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"

namespace tflite {
namespace task {
namespace vision {

// Performs segmentation on images.
//
// The API expects a TFLite model with optional, but strongly recommended,
// TFLite Model Metadata.
//
// Input tensor:
//   (kTfLiteUInt8/kTfLiteFloat32)
//    - image input of size `[batch x height x width x channels]`.
45 // - batch inference is not supported (`batch` is required to be 1). 46 // - only RGB inputs are supported (`channels` is required to be 3). 47 // - if type is kTfLiteFloat32, NormalizationOptions are required to be 48 // attached to the metadata for input normalization. 49 // Output tensor: 50 // (kTfLiteUInt8/kTfLiteFloat32) 51 // - tensor of size `[batch x mask_height x mask_width x num_classes]`, where 52 // `batch` is required to be 1, `mask_width` and `mask_height` are the 53 // dimensions of the segmentation masks produced by the model, and 54 // `num_classes` is the number of classes supported by the model. 55 // - optional (but recommended) label map(s) can be attached as 56 // AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per 57 // line. The first such AssociatedFile (if any) is used to fill the 58 // `class_name` field of the results. The `display_name` field is filled 59 // from the AssociatedFile (if any) whose locale matches the 60 // `display_names_locale` field of the `ImageSegmenterOptions` used at 61 // creation time ("en" by default, i.e. English). If none of these are 62 // available, only the `index` field of the results will be filled. 63 // 64 // An example of such model can be found at: 65 // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1 66 // 67 // A CLI demo tool is available for easily trying out this API, and provides 68 // example usage. See: 69 // examples/task/vision/desktop/image_segmenter_demo.cc 70 class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> { 71 public: 72 using BaseVisionTaskApi::BaseVisionTaskApi; 73 74 // Creates an ImageSegmenter from the provided options. A non-default 75 // OpResolver can be specified in order to support custom Ops or specify a 76 // subset of built-in Ops. 
77 static tflite::support::StatusOr<std::unique_ptr<ImageSegmenter>> 78 CreateFromOptions( 79 const ImageSegmenterOptions& options, 80 std::unique_ptr<tflite::OpResolver> resolver = 81 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); 82 83 // Performs actual segmentation on the provided FrameBuffer. 84 // 85 // The FrameBuffer can be of any size and any of the supported formats, i.e. 86 // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before 87 // inference in order to (and in this order): 88 // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to 89 // the dimensions of the model input tensor, 90 // - convert it to the colorspace of the input tensor (i.e. RGB, which is the 91 // only supported colorspace for now), 92 // - rotate it according to its `Orientation` so that inference is performed 93 // on an "upright" image. 94 // 95 // IMPORTANT: the returned segmentation masks are not direcly suited for 96 // display, in particular: 97 // * they are relative to the unrotated input frame, i.e. *not* taking into 98 // account the `Orientation` flag of the input FrameBuffer, 99 // * their dimensions are intrinsic to the model, i.e. *not* dependent on the 100 // input FrameBuffer dimensions. 101 // 102 // Example of such post-processing, assuming: 103 // * an input FrameBuffer with width=640, height=480, orientation=kLeftBottom 104 // (i.e. the image will be rotated 90° clockwise during preprocessing to 105 // make it "upright"), 106 // * a model outputting masks of size 224x224. 107 // In order to be directly displayable on top of the input image assumed to 108 // be displayed *with* the `Orientation` flag taken into account according to 109 // the EXIF specification (http://jpegclub.org/exif_orientation.html), the 110 // masks need to be: 111 // * re-scaled to 640 x 480, 112 // * then rotated 90° clockwise. 
113 tflite::support::StatusOr<SegmentationResult> Segment( 114 const FrameBuffer& frame_buffer); 115 116 protected: 117 // Post-processing to transform the raw model outputs into segmentation 118 // results. 119 tflite::support::StatusOr<SegmentationResult> Postprocess( 120 const std::vector<const TfLiteTensor*>& output_tensors, 121 const FrameBuffer& frame_buffer, const BoundingBox& roi) override; 122 123 // Performs sanity checks on the provided ImageSegmenterOptions. 124 static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options); 125 126 // Initializes the Segmenter from the provided ImageSegmenterOptions, whose 127 // ownership is transferred to this object. 128 absl::Status Init(std::unique_ptr<ImageSegmenterOptions> options); 129 130 // Performs pre-initialization actions. 131 virtual absl::Status PreInit(); 132 133 // The options used for building this image segmenter. 134 std::unique_ptr<ImageSegmenterOptions> options_; 135 136 // The label map, extracted from the TFLite Model Metadata. 137 std::vector<LabelMapItem> label_map_; 138 139 private: 140 // Performs sanity checks on the model outputs and extracts their metadata. 141 absl::Status CheckAndSetOutputs(); 142 143 // Initializes the colored labels list from `label_map_` and stores it in 144 // `colored_labels_`. 145 absl::Status InitColoredLabels(); 146 147 // Returns the output confidence at coordinates {x, y, depth}, dequantizing 148 // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true). 149 float GetOutputConfidence(const TfLiteTensor& output_tensor, int x, int y, 150 int depth); 151 152 // Prebuilt list of ColoredLabel attached to each Segmentation result. The 153 // i-th item in this list corresponds to the i-th label map item. 154 std::vector<Segmentation::ColoredLabel> colored_labels_; 155 156 // Whether the model features quantized inference type (QUANTIZED_UINT8). This 157 // is currently detected by checking if all output tensors data type is uint8. 
158 bool has_uint8_outputs_; 159 160 // Expected output width. 161 int output_width_; 162 // Expected output height. 163 int output_height_; 164 // Expected output depth. This corresponds to the number of supported classes. 165 int output_depth_; 166 }; 167 168 } // namespace vision 169 } // namespace task 170 } // namespace tflite 171 172 #endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_ 173