1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
17
18 #include <algorithm>
19
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "absl/strings/string_view.h"
23 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow_lite_support/cc/common.h"
26 #include "tensorflow_lite_support/cc/port/integral_types.h"
27 #include "tensorflow_lite_support/cc/port/status_macros.h"
28 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
29 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
30 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
31 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
32 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
33 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
34
35 namespace tflite {
36 namespace task {
37 namespace vision {
38
39 namespace {
40
41 using ::absl::StatusCode;
42 using ::tflite::TensorMetadata;
43 using ::tflite::metadata::ModelMetadataExtractor;
44 using ::tflite::support::CreateStatusWithPayload;
45 using ::tflite::support::StatusOr;
46 using ::tflite::support::TfLiteSupportStatus;
47 using ::tflite::task::core::AssertAndReturnTypedTensor;
48 using ::tflite::task::core::TaskAPIFactory;
49 using ::tflite::task::core::TfLiteEngine;
50
// The maximum number of labels allowed in the labelmap. This is because so far
// segmentation masks are stored with 8 bit per pixel (flattened byte array),
// so at most 256 distinct class indices can be represented.
constexpr uint32 kMaxNumClasses = 256;

// The colormap used to fill `ColoredLabel`-s, as a flattened array of 256
// {R, G, B} components, i.e. 768 bytes total. Index 3*i .. 3*i+2 holds the
// color assigned to class i. NOTE(review): the values look like the standard
// PASCAL VOC segmentation palette (bit-reversed class index) — confirm before
// documenting it as such.
constexpr uint8 kColorMap[768] = {
    0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128,
    128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 192, 0, 0,
    64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128, 64, 128, 128,
    192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0,
    0, 64, 128, 128, 64, 128, 0, 192, 128, 128, 192, 128, 64, 64, 0,
    192, 64, 0, 64, 192, 0, 192, 192, 0, 64, 64, 128, 192, 64, 128,
    64, 192, 128, 192, 192, 128, 0, 0, 64, 128, 0, 64, 0, 128, 64,
    128, 128, 64, 0, 0, 192, 128, 0, 192, 0, 128, 192, 128, 128, 192,
    64, 0, 64, 192, 0, 64, 64, 128, 64, 192, 128, 64, 64, 0, 192,
    192, 0, 192, 64, 128, 192, 192, 128, 192, 0, 64, 64, 128, 64, 64,
    0, 192, 64, 128, 192, 64, 0, 64, 192, 128, 64, 192, 0, 192, 192,
    128, 192, 192, 64, 64, 64, 192, 64, 64, 64, 192, 64, 192, 192, 64,
    64, 64, 192, 192, 64, 192, 64, 192, 192, 192, 192, 192, 32, 0, 0,
    160, 0, 0, 32, 128, 0, 160, 128, 0, 32, 0, 128, 160, 0, 128,
    32, 128, 128, 160, 128, 128, 96, 0, 0, 224, 0, 0, 96, 128, 0,
    224, 128, 0, 96, 0, 128, 224, 0, 128, 96, 128, 128, 224, 128, 128,
    32, 64, 0, 160, 64, 0, 32, 192, 0, 160, 192, 0, 32, 64, 128,
    160, 64, 128, 32, 192, 128, 160, 192, 128, 96, 64, 0, 224, 64, 0,
    96, 192, 0, 224, 192, 0, 96, 64, 128, 224, 64, 128, 96, 192, 128,
    224, 192, 128, 32, 0, 64, 160, 0, 64, 32, 128, 64, 160, 128, 64,
    32, 0, 192, 160, 0, 192, 32, 128, 192, 160, 128, 192, 96, 0, 64,
    224, 0, 64, 96, 128, 64, 224, 128, 64, 96, 0, 192, 224, 0, 192,
    96, 128, 192, 224, 128, 192, 32, 64, 64, 160, 64, 64, 32, 192, 64,
    160, 192, 64, 32, 64, 192, 160, 64, 192, 32, 192, 192, 160, 192, 192,
    96, 64, 64, 224, 64, 64, 96, 192, 64, 224, 192, 64, 96, 64, 192,
    224, 64, 192, 96, 192, 192, 224, 192, 192, 0, 32, 0, 128, 32, 0,
    0, 160, 0, 128, 160, 0, 0, 32, 128, 128, 32, 128, 0, 160, 128,
    128, 160, 128, 64, 32, 0, 192, 32, 0, 64, 160, 0, 192, 160, 0,
    64, 32, 128, 192, 32, 128, 64, 160, 128, 192, 160, 128, 0, 96, 0,
    128, 96, 0, 0, 224, 0, 128, 224, 0, 0, 96, 128, 128, 96, 128,
    0, 224, 128, 128, 224, 128, 64, 96, 0, 192, 96, 0, 64, 224, 0,
    192, 224, 0, 64, 96, 128, 192, 96, 128, 64, 224, 128, 192, 224, 128,
    0, 32, 64, 128, 32, 64, 0, 160, 64, 128, 160, 64, 0, 32, 192,
    128, 32, 192, 0, 160, 192, 128, 160, 192, 64, 32, 64, 192, 32, 64,
    64, 160, 64, 192, 160, 64, 64, 32, 192, 192, 32, 192, 64, 160, 192,
    192, 160, 192, 0, 96, 64, 128, 96, 64, 0, 224, 64, 128, 224, 64,
    0, 96, 192, 128, 96, 192, 0, 224, 192, 128, 224, 192, 64, 96, 64,
    192, 96, 64, 64, 224, 64, 192, 224, 64, 64, 96, 192, 192, 96, 192,
    64, 224, 192, 192, 224, 192, 32, 32, 0, 160, 32, 0, 32, 160, 0,
    160, 160, 0, 32, 32, 128, 160, 32, 128, 32, 160, 128, 160, 160, 128,
    96, 32, 0, 224, 32, 0, 96, 160, 0, 224, 160, 0, 96, 32, 128,
    224, 32, 128, 96, 160, 128, 224, 160, 128, 32, 96, 0, 160, 96, 0,
    32, 224, 0, 160, 224, 0, 32, 96, 128, 160, 96, 128, 32, 224, 128,
    160, 224, 128, 96, 96, 0, 224, 96, 0, 96, 224, 0, 224, 224, 0,
    96, 96, 128, 224, 96, 128, 96, 224, 128, 224, 224, 128, 32, 32, 64,
    160, 32, 64, 32, 160, 64, 160, 160, 64, 32, 32, 192, 160, 32, 192,
    32, 160, 192, 160, 160, 192, 96, 32, 64, 224, 32, 64, 96, 160, 64,
    224, 160, 64, 96, 32, 192, 224, 32, 192, 96, 160, 192, 224, 160, 192,
    32, 96, 64, 160, 96, 64, 32, 224, 64, 160, 224, 64, 32, 96, 192,
    160, 96, 192, 32, 224, 192, 160, 224, 192, 96, 96, 64, 224, 96, 64,
    96, 224, 64, 224, 224, 64, 96, 96, 192, 224, 96, 192, 96, 224, 192,
    224, 224, 192};
111
GetLabelMapIfAny(const ModelMetadataExtractor & metadata_extractor,const TensorMetadata & tensor_metadata,absl::string_view locale)112 StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
113 const ModelMetadataExtractor& metadata_extractor,
114 const TensorMetadata& tensor_metadata, absl::string_view locale) {
115 const std::string labels_filename =
116 ModelMetadataExtractor::FindFirstAssociatedFileName(
117 tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
118 if (labels_filename.empty()) {
119 return std::vector<LabelMapItem>();
120 }
121 ASSIGN_OR_RETURN(absl::string_view labels_file,
122 metadata_extractor.GetAssociatedFile(labels_filename));
123 const std::string display_names_filename =
124 ModelMetadataExtractor::FindFirstAssociatedFileName(
125 tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
126 locale);
127 absl::string_view display_names_file = nullptr;
128 if (!display_names_filename.empty()) {
129 ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
130 display_names_filename));
131 }
132 return BuildLabelMapFromFiles(labels_file, display_names_file);
133 }
134
135 } // namespace
136
137 /* static */
SanityCheckOptions(const ImageSegmenterOptions & options)138 absl::Status ImageSegmenter::SanityCheckOptions(
139 const ImageSegmenterOptions& options) {
140 if (!options.has_model_file_with_metadata()) {
141 return CreateStatusWithPayload(
142 StatusCode::kInvalidArgument,
143 "Missing mandatory `model_file_with_metadata` field",
144 TfLiteSupportStatus::kInvalidArgumentError);
145 }
146 if (options.output_type() == ImageSegmenterOptions::UNSPECIFIED) {
147 return CreateStatusWithPayload(
148 StatusCode::kInvalidArgument,
149 "ImageSegmenterOptions: `output_type` must not be UNSPECIFIED",
150 TfLiteSupportStatus::kInvalidArgumentError);
151 }
152 if (options.num_threads() == 0 || options.num_threads() < -1) {
153 return CreateStatusWithPayload(
154 StatusCode::kInvalidArgument,
155 "`num_threads` must be greater than 0 or equal to -1.",
156 TfLiteSupportStatus::kInvalidArgumentError);
157 }
158 return absl::OkStatus();
159 }
160
CreateFromOptions(const ImageSegmenterOptions & options,std::unique_ptr<tflite::OpResolver> resolver)161 StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::CreateFromOptions(
162 const ImageSegmenterOptions& options,
163 std::unique_ptr<tflite::OpResolver> resolver) {
164 RETURN_IF_ERROR(SanityCheckOptions(options));
165
166 // Copy options to ensure the ExternalFile outlives the constructed object.
167 auto options_copy = absl::make_unique<ImageSegmenterOptions>(options);
168
169 ASSIGN_OR_RETURN(auto image_segmenter,
170 TaskAPIFactory::CreateFromExternalFileProto<ImageSegmenter>(
171 &options_copy->model_file_with_metadata(),
172 std::move(resolver), options_copy->num_threads()));
173
174 RETURN_IF_ERROR(image_segmenter->Init(std::move(options_copy)));
175
176 return image_segmenter;
177 }
178
// Finishes construction after the engine has been created: takes ownership of
// the (already sanity-checked) options, validates model inputs/outputs and
// pre-computes the colored labels. Each step returns early on error.
absl::Status ImageSegmenter::Init(
    std::unique_ptr<ImageSegmenterOptions> options) {
  // Set options. Ownership is kept for the lifetime of the segmenter, as the
  // ExternalFile proto inside must outlive the engine.
  options_ = std::move(options);

  // Perform pre-initialization actions (by default, sets the process engine for
  // image pre-processing to kLibyuv as a sane default).
  RETURN_IF_ERROR(PreInit());

  // Sanity check and set inputs and outputs. CheckAndSetOutputs also builds
  // label_map_, which InitColoredLabels below depends on, so order matters.
  RETURN_IF_ERROR(CheckAndSetInputs());
  RETURN_IF_ERROR(CheckAndSetOutputs());

  // Initialize colored_labels_ once and for all.
  RETURN_IF_ERROR(InitColoredLabels());

  return absl::OkStatus();
}
197
// Hook run at the start of Init: selects libyuv as the default engine for
// image pre-processing (resize, color conversion, orientation).
absl::Status ImageSegmenter::PreInit() {
  SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  return absl::OkStatus();
}
202
// Validates the model's single output tensor (shape [1, height, width,
// classes], float32 or uint8), caches its dimensions and quantization flag,
// and builds label_map_ from metadata when available (falling back to empty
// items otherwise). Returns kInvalidArgument with a specific
// TfLiteSupportStatus payload on the first violation.
absl::Status ImageSegmenter::CheckAndSetOutputs() {
  // First, sanity checks on the model itself.
  const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();

  // Check the number of output tensors: exactly one is expected.
  if (TfLiteEngine::OutputCount(interpreter) != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Image segmentation models are expected to have only 1 "
                        "output, found %d",
                        TfLiteEngine::OutputCount(interpreter)),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  const TfLiteTensor* output_tensor = TfLiteEngine::GetOutput(interpreter, 0);

  // Check tensor dimensions: expected layout is [batch, height, width, depth].
  if (output_tensor->dims->size != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Output tensor is expected to have 4 dimensions, found %d.",
            output_tensor->dims->size),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }
  if (output_tensor->dims->data[0] != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected batch size of 1, found %d.",
                        output_tensor->dims->data[0]),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }
  // Cache the spatial dimensions and number of classes for Postprocess and
  // GetOutputConfidence.
  output_height_ = output_tensor->dims->data[1];
  output_width_ = output_tensor->dims->data[2];
  output_depth_ = output_tensor->dims->data[3];
  // Category masks are 8 bit per pixel, so at most 256 classes can be encoded.
  if (output_depth_ > kMaxNumClasses) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected at most %d output classes, found %d",
                        kMaxNumClasses, output_depth_),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }

  // Check tensor type: only float32 (dequantized) and uint8 (quantized)
  // outputs are supported.
  if (output_tensor->type != kTfLiteFloat32 &&
      output_tensor->type != kTfLiteUInt8) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Type mismatch for output tensor. Requested one of "
                        "these types: kTfLiteUint8/kTfLiteFloat32, got %s.",
                        TfLiteTypeGetName(output_tensor->type)),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  // Remembered so GetOutputConfidence knows whether to dequantize.
  has_uint8_outputs_ = (output_tensor->type == kTfLiteUInt8);

  // Build label map from metadata, if available.
  const ModelMetadataExtractor* metadata_extractor =
      engine_->metadata_extractor();
  const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
      output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensor_metadata != nullptr) {
    // Check metadata consistency: metadata must describe exactly the one
    // output tensor validated above.
    if (output_tensor_metadata->size() != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Mismatch between number of output tensors (1) and "
                          "output tensors metadata (%d).",
                          output_tensor_metadata->size()),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    ASSIGN_OR_RETURN(
        label_map_,
        GetLabelMapIfAny(*metadata_extractor, *output_tensor_metadata->Get(0),
                         options_->display_names_locale()));
  }

  // If label map is still empty (no metadata, or metadata without labels),
  // build a default one with one empty item per output class.
  if (label_map_.empty()) {
    for (int class_index = 0; class_index < output_depth_; ++class_index) {
      label_map_.emplace_back(LabelMapItem{});
    }
  }

  return absl::OkStatus();
}
287
InitColoredLabels()288 absl::Status ImageSegmenter::InitColoredLabels() {
289 for (int i = 0; i < label_map_.size(); ++i) {
290 Segmentation::ColoredLabel colored_label;
291 colored_label.set_r(kColorMap[3 * i]);
292 colored_label.set_g(kColorMap[3 * i + 1]);
293 colored_label.set_b(kColorMap[3 * i + 2]);
294 const LabelMapItem& item = label_map_[i];
295 if (!item.name.empty()) {
296 colored_label.set_class_name(item.name);
297 }
298 if (!item.display_name.empty()) {
299 colored_label.set_display_name(item.display_name);
300 }
301 colored_labels_.push_back(colored_label);
302 }
303 return absl::OkStatus();
304 }
305
Segment(const FrameBuffer & frame_buffer)306 StatusOr<SegmentationResult> ImageSegmenter::Segment(
307 const FrameBuffer& frame_buffer) {
308 BoundingBox roi;
309 roi.set_width(frame_buffer.dimension().width);
310 roi.set_height(frame_buffer.dimension().height);
311 return InferWithFallback(frame_buffer, roi);
312 }
313
// Converts the single raw output tensor into a SegmentationResult, producing
// either a category mask (argmax over classes, one byte per pixel) or
// per-class confidence masks depending on options_->output_type(). Masks are
// emitted in the unrotated (kTopLeft) frame of reference, re-orienting
// coordinates pixel by pixel from the tensor's orientation. The ROI is unused:
// the whole tensor is converted.
StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
  // Internal invariant (guaranteed by CheckAndSetOutputs): one output tensor.
  if (output_tensors.size() != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Expected 1 output tensors, found %d",
                        output_tensors.size()));
  }
  const TfLiteTensor* output_tensor = output_tensors[0];

  SegmentationResult result;
  Segmentation* segmentation = result.add_segmentation();
  // Copy the pre-computed colored labels into the result proto.
  *segmentation->mutable_colored_labels() = {colored_labels_.begin(),
                                             colored_labels_.end()};

  // The output tensor has orientation `frame_buffer.orientation()`, as it has
  // been produced from the pre-processed frame.
  FrameBuffer::Orientation tensor_orientation = frame_buffer.orientation();
  // The output tensor always has size `output_width_ x output_height_`
  FrameBuffer::Dimension tensor_dimension = {output_width_, output_height_};

  // The masks to produce from the output tensor need to be re-oriented in the
  // unrotated frame of reference coordinates system, i.e. kTopLeft.
  FrameBuffer::Orientation mask_orientation =
      FrameBuffer::Orientation::kTopLeft;
  // They may thus have swapped dimensions compared to the tensor if the
  // rotation is 90° or 270°.
  FrameBuffer::Dimension mask_dimension(tensor_dimension);
  if (RequireDimensionSwap(frame_buffer.orientation(),
                           FrameBuffer::Orientation::kTopLeft)) {
    mask_dimension.Swap();
  }
  segmentation->set_width(mask_dimension.width);
  segmentation->set_height(mask_dimension.height);

  // XY coordinates in the tensor, to be computed from mask_x and mask_y below.
  int tensor_x;
  int tensor_y;

  if (options_->output_type() == ImageSegmenterOptions::CATEGORY_MASK) {
    // One byte per pixel, row-major, holding the argmax class index.
    auto* category_mask = segmentation->mutable_category_mask();
    category_mask->resize(mask_dimension.width * mask_dimension.height);
    int pixel_offset = 0;
    for (int mask_y = 0; mask_y < mask_dimension.height; ++mask_y) {
      for (int mask_x = 0; mask_x < mask_dimension.width; ++mask_x) {
        // Compute the coordinates (tensor_x, tensor_y) in the tensor with
        // tensor_orientation = frame_buffer.orientation() corresponding to the
        // coordinates (mask_x, mask_y) in the mask being filled with
        // mask_orientation = kTopLeft, i.e. the orientation of the unrotated
        // frame of reference.
        OrientCoordinates(/*from_x=*/mask_x,
                          /*from_y=*/mask_y,
                          /*from_orientation=*/mask_orientation,
                          /*to_orientation=*/tensor_orientation,
                          /*from_dimension=*/mask_dimension,
                          /*to_x=*/&tensor_x,
                          /*to_y=*/&tensor_y);
        // Argmax over the class dimension. Note: ties and all-zero confidences
        // resolve to class 0, and negative confidences never win over the
        // initial 0.0f threshold.
        int class_index = 0;
        float max_confidence = 0.0f;
        for (int d = 0; d < output_depth_; ++d) {
          const float confidence =
              GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d);
          if (confidence > max_confidence) {
            class_index = d;
            max_confidence = confidence;
          }
        }
        (*category_mask)[pixel_offset++] = static_cast<char>(class_index);
      }
    }
  } else if (options_->output_type() ==
             ImageSegmenterOptions::CONFIDENCE_MASK) {
    // One float mask per class, each filled row-major below.
    auto* confidence_masks = segmentation->mutable_confidence_masks();
    for (int d = 0; d < output_depth_; ++d) {
      confidence_masks->add_confidence_mask();
    }
    for (int mask_y = 0; mask_y < segmentation->height(); ++mask_y) {
      for (int mask_x = 0; mask_x < segmentation->width(); ++mask_x) {
        // See above.
        OrientCoordinates(/*from_x=*/mask_x,
                          /*from_y=*/mask_y,
                          /*from_orientation=*/mask_orientation,
                          /*to_orientation=*/tensor_orientation,
                          /*from_dimension=*/mask_dimension,
                          /*to_x=*/&tensor_x,
                          /*to_y=*/&tensor_y);
        for (int d = 0; d < output_depth_; ++d) {
          confidence_masks->mutable_confidence_mask(d)->add_value(
              GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d));
        }
      }
    }
  }

  return result;
}
411
GetOutputConfidence(const TfLiteTensor & output_tensor,int x,int y,int depth)412 float ImageSegmenter::GetOutputConfidence(const TfLiteTensor& output_tensor,
413 int x, int y, int depth) {
414 int index = output_width_ * output_depth_ * y + output_depth_ * x + depth;
415 if (has_uint8_outputs_) {
416 const uint8* data = AssertAndReturnTypedTensor<uint8>(&output_tensor);
417 return output_tensor.params.scale *
418 (static_cast<int>(data[index]) - output_tensor.params.zero_point);
419 } else {
420 const float* data = AssertAndReturnTypedTensor<float>(&output_tensor);
421 return data[index];
422 }
423 }
424
425 } // namespace vision
426 } // namespace task
427 } // namespace tflite
428