• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
17 
18 #include <algorithm>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "absl/strings/string_view.h"
23 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow_lite_support/cc/common.h"
26 #include "tensorflow_lite_support/cc/port/integral_types.h"
27 #include "tensorflow_lite_support/cc/port/status_macros.h"
28 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
29 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
30 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
31 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
32 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
33 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
34 
35 namespace tflite {
36 namespace task {
37 namespace vision {
38 
39 namespace {
40 
41 using ::absl::StatusCode;
42 using ::tflite::TensorMetadata;
43 using ::tflite::metadata::ModelMetadataExtractor;
44 using ::tflite::support::CreateStatusWithPayload;
45 using ::tflite::support::StatusOr;
46 using ::tflite::support::TfLiteSupportStatus;
47 using ::tflite::task::core::AssertAndReturnTypedTensor;
48 using ::tflite::task::core::TaskAPIFactory;
49 using ::tflite::task::core::TfLiteEngine;
50 
// The maximum number of labels allowed in the labelmap. This is because so far
// segmentation masks are stored with 8 bit per pixel (flattened byte array).
constexpr uint32 kMaxNumClasses = 256;

// TODO: restore the bug reference here (original comment read "TODO(b/)" with
// the bug number missing).
// The colormap used to fill `ColoredLabel`-s, as a flattened array of 256 {R,
// G, B} components, i.e. color i occupies entries [3*i, 3*i+2].
constexpr uint8 kColorMap[768] = {
    0,   0,   0,   128, 0,   0,   0,   128, 0,   128, 128, 0,   0,   0,   128,
    128, 0,   128, 0,   128, 128, 128, 128, 128, 64,  0,   0,   192, 0,   0,
    64,  128, 0,   192, 128, 0,   64,  0,   128, 192, 0,   128, 64,  128, 128,
    192, 128, 128, 0,   64,  0,   128, 64,  0,   0,   192, 0,   128, 192, 0,
    0,   64,  128, 128, 64,  128, 0,   192, 128, 128, 192, 128, 64,  64,  0,
    192, 64,  0,   64,  192, 0,   192, 192, 0,   64,  64,  128, 192, 64,  128,
    64,  192, 128, 192, 192, 128, 0,   0,   64,  128, 0,   64,  0,   128, 64,
    128, 128, 64,  0,   0,   192, 128, 0,   192, 0,   128, 192, 128, 128, 192,
    64,  0,   64,  192, 0,   64,  64,  128, 64,  192, 128, 64,  64,  0,   192,
    192, 0,   192, 64,  128, 192, 192, 128, 192, 0,   64,  64,  128, 64,  64,
    0,   192, 64,  128, 192, 64,  0,   64,  192, 128, 64,  192, 0,   192, 192,
    128, 192, 192, 64,  64,  64,  192, 64,  64,  64,  192, 64,  192, 192, 64,
    64,  64,  192, 192, 64,  192, 64,  192, 192, 192, 192, 192, 32,  0,   0,
    160, 0,   0,   32,  128, 0,   160, 128, 0,   32,  0,   128, 160, 0,   128,
    32,  128, 128, 160, 128, 128, 96,  0,   0,   224, 0,   0,   96,  128, 0,
    224, 128, 0,   96,  0,   128, 224, 0,   128, 96,  128, 128, 224, 128, 128,
    32,  64,  0,   160, 64,  0,   32,  192, 0,   160, 192, 0,   32,  64,  128,
    160, 64,  128, 32,  192, 128, 160, 192, 128, 96,  64,  0,   224, 64,  0,
    96,  192, 0,   224, 192, 0,   96,  64,  128, 224, 64,  128, 96,  192, 128,
    224, 192, 128, 32,  0,   64,  160, 0,   64,  32,  128, 64,  160, 128, 64,
    32,  0,   192, 160, 0,   192, 32,  128, 192, 160, 128, 192, 96,  0,   64,
    224, 0,   64,  96,  128, 64,  224, 128, 64,  96,  0,   192, 224, 0,   192,
    96,  128, 192, 224, 128, 192, 32,  64,  64,  160, 64,  64,  32,  192, 64,
    160, 192, 64,  32,  64,  192, 160, 64,  192, 32,  192, 192, 160, 192, 192,
    96,  64,  64,  224, 64,  64,  96,  192, 64,  224, 192, 64,  96,  64,  192,
    224, 64,  192, 96,  192, 192, 224, 192, 192, 0,   32,  0,   128, 32,  0,
    0,   160, 0,   128, 160, 0,   0,   32,  128, 128, 32,  128, 0,   160, 128,
    128, 160, 128, 64,  32,  0,   192, 32,  0,   64,  160, 0,   192, 160, 0,
    64,  32,  128, 192, 32,  128, 64,  160, 128, 192, 160, 128, 0,   96,  0,
    128, 96,  0,   0,   224, 0,   128, 224, 0,   0,   96,  128, 128, 96,  128,
    0,   224, 128, 128, 224, 128, 64,  96,  0,   192, 96,  0,   64,  224, 0,
    192, 224, 0,   64,  96,  128, 192, 96,  128, 64,  224, 128, 192, 224, 128,
    0,   32,  64,  128, 32,  64,  0,   160, 64,  128, 160, 64,  0,   32,  192,
    128, 32,  192, 0,   160, 192, 128, 160, 192, 64,  32,  64,  192, 32,  64,
    64,  160, 64,  192, 160, 64,  64,  32,  192, 192, 32,  192, 64,  160, 192,
    192, 160, 192, 0,   96,  64,  128, 96,  64,  0,   224, 64,  128, 224, 64,
    0,   96,  192, 128, 96,  192, 0,   224, 192, 128, 224, 192, 64,  96,  64,
    192, 96,  64,  64,  224, 64,  192, 224, 64,  64,  96,  192, 192, 96,  192,
    64,  224, 192, 192, 224, 192, 32,  32,  0,   160, 32,  0,   32,  160, 0,
    160, 160, 0,   32,  32,  128, 160, 32,  128, 32,  160, 128, 160, 160, 128,
    96,  32,  0,   224, 32,  0,   96,  160, 0,   224, 160, 0,   96,  32,  128,
    224, 32,  128, 96,  160, 128, 224, 160, 128, 32,  96,  0,   160, 96,  0,
    32,  224, 0,   160, 224, 0,   32,  96,  128, 160, 96,  128, 32,  224, 128,
    160, 224, 128, 96,  96,  0,   224, 96,  0,   96,  224, 0,   224, 224, 0,
    96,  96,  128, 224, 96,  128, 96,  224, 128, 224, 224, 128, 32,  32,  64,
    160, 32,  64,  32,  160, 64,  160, 160, 64,  32,  32,  192, 160, 32,  192,
    32,  160, 192, 160, 160, 192, 96,  32,  64,  224, 32,  64,  96,  160, 64,
    224, 160, 64,  96,  32,  192, 224, 32,  192, 96,  160, 192, 224, 160, 192,
    32,  96,  64,  160, 96,  64,  32,  224, 64,  160, 224, 64,  32,  96,  192,
    160, 96,  192, 32,  224, 192, 160, 224, 192, 96,  96,  64,  224, 96,  64,
    96,  224, 64,  224, 224, 64,  96,  96,  192, 224, 96,  192, 96,  224, 192,
    224, 224, 192};
111 
GetLabelMapIfAny(const ModelMetadataExtractor & metadata_extractor,const TensorMetadata & tensor_metadata,absl::string_view locale)112 StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
113     const ModelMetadataExtractor& metadata_extractor,
114     const TensorMetadata& tensor_metadata, absl::string_view locale) {
115   const std::string labels_filename =
116       ModelMetadataExtractor::FindFirstAssociatedFileName(
117           tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
118   if (labels_filename.empty()) {
119     return std::vector<LabelMapItem>();
120   }
121   ASSIGN_OR_RETURN(absl::string_view labels_file,
122                    metadata_extractor.GetAssociatedFile(labels_filename));
123   const std::string display_names_filename =
124       ModelMetadataExtractor::FindFirstAssociatedFileName(
125           tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
126           locale);
127   absl::string_view display_names_file = nullptr;
128   if (!display_names_filename.empty()) {
129     ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
130                                              display_names_filename));
131   }
132   return BuildLabelMapFromFiles(labels_file, display_names_file);
133 }
134 
135 }  // namespace
136 
137 /* static */
SanityCheckOptions(const ImageSegmenterOptions & options)138 absl::Status ImageSegmenter::SanityCheckOptions(
139     const ImageSegmenterOptions& options) {
140   if (!options.has_model_file_with_metadata()) {
141     return CreateStatusWithPayload(
142         StatusCode::kInvalidArgument,
143         "Missing mandatory `model_file_with_metadata` field",
144         TfLiteSupportStatus::kInvalidArgumentError);
145   }
146   if (options.output_type() == ImageSegmenterOptions::UNSPECIFIED) {
147     return CreateStatusWithPayload(
148         StatusCode::kInvalidArgument,
149         "ImageSegmenterOptions: `output_type` must not be UNSPECIFIED",
150         TfLiteSupportStatus::kInvalidArgumentError);
151   }
152   if (options.num_threads() == 0 || options.num_threads() < -1) {
153     return CreateStatusWithPayload(
154         StatusCode::kInvalidArgument,
155         "`num_threads` must be greater than 0 or equal to -1.",
156         TfLiteSupportStatus::kInvalidArgumentError);
157   }
158   return absl::OkStatus();
159 }
160 
CreateFromOptions(const ImageSegmenterOptions & options,std::unique_ptr<tflite::OpResolver> resolver)161 StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::CreateFromOptions(
162     const ImageSegmenterOptions& options,
163     std::unique_ptr<tflite::OpResolver> resolver) {
164   RETURN_IF_ERROR(SanityCheckOptions(options));
165 
166   // Copy options to ensure the ExternalFile outlives the constructed object.
167   auto options_copy = absl::make_unique<ImageSegmenterOptions>(options);
168 
169   ASSIGN_OR_RETURN(auto image_segmenter,
170                    TaskAPIFactory::CreateFromExternalFileProto<ImageSegmenter>(
171                        &options_copy->model_file_with_metadata(),
172                        std::move(resolver), options_copy->num_threads()));
173 
174   RETURN_IF_ERROR(image_segmenter->Init(std::move(options_copy)));
175 
176   return image_segmenter;
177 }
178 
Init(std::unique_ptr<ImageSegmenterOptions> options)179 absl::Status ImageSegmenter::Init(
180     std::unique_ptr<ImageSegmenterOptions> options) {
181   // Set options.
182   options_ = std::move(options);
183 
184   // Perform pre-initialization actions (by default, sets the process engine for
185   // image pre-processing to kLibyuv as a sane default).
186   RETURN_IF_ERROR(PreInit());
187 
188   // Sanity check and set inputs and outputs.
189   RETURN_IF_ERROR(CheckAndSetInputs());
190   RETURN_IF_ERROR(CheckAndSetOutputs());
191 
192   // Initialize colored_labels_ once and for all.
193   RETURN_IF_ERROR(InitColoredLabels());
194 
195   return absl::OkStatus();
196 }
197 
// Pre-initialization hook invoked by Init() before tensor checks: selects
// libyuv as the engine used for frame-buffer pre-processing operations.
absl::Status ImageSegmenter::PreInit() {
  SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  return absl::OkStatus();
}
202 
// Validates the model's output and caches its properties:
// - exactly one output tensor, of rank 4 with layout [batch, height, width,
//   depth] and batch size 1 (see the dims checks below);
// - depth (number of classes) at most kMaxNumClasses, since masks are stored
//   with 8 bits per pixel;
// - tensor type float32 or uint8 (has_uint8_outputs_ records which).
// Also populates label_map_, either from the TFLite Metadata or with
// `output_depth_` empty placeholder items as a fallback.
absl::Status ImageSegmenter::CheckAndSetOutputs() {
  // First, sanity checks on the model itself.
  const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();

  // Check the number of output tensors.
  if (TfLiteEngine::OutputCount(interpreter) != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Image segmentation models are expected to have only 1 "
                        "output, found %d",
                        TfLiteEngine::OutputCount(interpreter)),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  const TfLiteTensor* output_tensor = TfLiteEngine::GetOutput(interpreter, 0);

  // Check tensor dimensions.
  if (output_tensor->dims->size != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Output tensor is expected to have 4 dimensions, found %d.",
            output_tensor->dims->size),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }
  if (output_tensor->dims->data[0] != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected batch size of 1, found %d.",
                        output_tensor->dims->data[0]),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }
  // Cache the [height, width, depth] dimensions used by Postprocess().
  output_height_ = output_tensor->dims->data[1];
  output_width_ = output_tensor->dims->data[2];
  output_depth_ = output_tensor->dims->data[3];
  // Category masks are stored as one byte per pixel, so at most 256 classes
  // can be represented.
  if (output_depth_ > kMaxNumClasses) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected at most %d output classes, found %d",
                        kMaxNumClasses, output_depth_),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }

  // Check tensor type.
  if (output_tensor->type != kTfLiteFloat32 &&
      output_tensor->type != kTfLiteUInt8) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Type mismatch for output tensor. Requested one of "
                        "these types: kTfLiteUint8/kTfLiteFloat32, got %s.",
                        TfLiteTypeGetName(output_tensor->type)),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  // Uint8 outputs require dequantization in GetOutputConfidence().
  has_uint8_outputs_ = (output_tensor->type == kTfLiteUInt8);

  // Build label map from metadata, if available.
  const ModelMetadataExtractor* metadata_extractor =
      engine_->metadata_extractor();
  const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
      output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensor_metadata != nullptr) {
    // Check metadata consistency.
    if (output_tensor_metadata->size() != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Mismatch between number of output tensors (1) and "
                          "output tensors metadata (%d).",
                          output_tensor_metadata->size()),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    ASSIGN_OR_RETURN(
        label_map_,
        GetLabelMapIfAny(*metadata_extractor, *output_tensor_metadata->Get(0),
                         options_->display_names_locale()));
  }

  // If label map is still empty, build a default one with one (empty) item
  // per output class so that label_map_ and colored_labels_ always have
  // output_depth_ entries.
  if (label_map_.empty()) {
    for (int class_index = 0; class_index < output_depth_; ++class_index) {
      label_map_.emplace_back(LabelMapItem{});
    }
  }

  return absl::OkStatus();
}
287 
InitColoredLabels()288 absl::Status ImageSegmenter::InitColoredLabels() {
289   for (int i = 0; i < label_map_.size(); ++i) {
290     Segmentation::ColoredLabel colored_label;
291     colored_label.set_r(kColorMap[3 * i]);
292     colored_label.set_g(kColorMap[3 * i + 1]);
293     colored_label.set_b(kColorMap[3 * i + 2]);
294     const LabelMapItem& item = label_map_[i];
295     if (!item.name.empty()) {
296       colored_label.set_class_name(item.name);
297     }
298     if (!item.display_name.empty()) {
299       colored_label.set_display_name(item.display_name);
300     }
301     colored_labels_.push_back(colored_label);
302   }
303   return absl::OkStatus();
304 }
305 
Segment(const FrameBuffer & frame_buffer)306 StatusOr<SegmentationResult> ImageSegmenter::Segment(
307     const FrameBuffer& frame_buffer) {
308   BoundingBox roi;
309   roi.set_width(frame_buffer.dimension().width);
310   roi.set_height(frame_buffer.dimension().height);
311   return InferWithFallback(frame_buffer, roi);
312 }
313 
// Converts the raw output tensor into a SegmentationResult, re-orienting the
// produced mask(s) from the pre-processed frame's orientation back into the
// unrotated (kTopLeft) frame of reference. Depending on
// options_->output_type(), fills either a single category mask (per-pixel
// argmax over classes) or one confidence mask per class. The `roi` argument
// is unused here.
StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
  // Internal invariant (enforced at setup time in CheckAndSetOutputs):
  // exactly one output tensor.
  if (output_tensors.size() != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Expected 1 output tensors, found %d",
                        output_tensors.size()));
  }
  const TfLiteTensor* output_tensor = output_tensors[0];

  SegmentationResult result;
  Segmentation* segmentation = result.add_segmentation();
  // Attach the pre-computed per-class colors (built in InitColoredLabels).
  *segmentation->mutable_colored_labels() = {colored_labels_.begin(),
                                             colored_labels_.end()};

  // The output tensor has orientation `frame_buffer.orientation()`, as it has
  // been produced from the pre-processed frame.
  FrameBuffer::Orientation tensor_orientation = frame_buffer.orientation();
  // The output tensor always has size `output_width_ x output_height_`
  FrameBuffer::Dimension tensor_dimension = {output_width_, output_height_};

  // The masks to produce from the output tensor need to be re-oriented in the
  // unrotated frame of reference coordinates system, i.e. kTopLeft.
  FrameBuffer::Orientation mask_orientation =
      FrameBuffer::Orientation::kTopLeft;
  // They may thus have swapped dimensions compared to the tensor if the
  // rotation is 90° or 270°.
  FrameBuffer::Dimension mask_dimension(tensor_dimension);
  if (RequireDimensionSwap(frame_buffer.orientation(),
                           FrameBuffer::Orientation::kTopLeft)) {
    mask_dimension.Swap();
  }
  segmentation->set_width(mask_dimension.width);
  segmentation->set_height(mask_dimension.height);

  // XY coordinates in the tensor, to be computed from mask_x and mask_y below.
  int tensor_x;
  int tensor_y;

  if (options_->output_type() == ImageSegmenterOptions::CATEGORY_MASK) {
    // One byte per pixel, holding the index of the most confident class.
    auto* category_mask = segmentation->mutable_category_mask();
    category_mask->resize(mask_dimension.width * mask_dimension.height);
    int pixel_offset = 0;
    for (int mask_y = 0; mask_y < mask_dimension.height; ++mask_y) {
      for (int mask_x = 0; mask_x < mask_dimension.width; ++mask_x) {
        // Compute the coordinates (tensor_x, tensor_y) in the tensor with
        // tensor_orientation = frame_buffer.orientation() corresponding to the
        // coordinates (mask_x, mask_y) in the mask being filled with
        // mask_orientation = kTopLeft, i.e. the orientation of the unrotated
        // frame of reference.
        OrientCoordinates(/*from_x=*/mask_x,
                          /*from_y=*/mask_y,
                          /*from_orientation=*/mask_orientation,
                          /*to_orientation=*/tensor_orientation,
                          /*from_dimension=*/mask_dimension,
                          /*to_x=*/&tensor_x,
                          /*to_y=*/&tensor_y);
        // Argmax over the class dimension. Ties and all-zero confidences
        // resolve to class 0.
        int class_index = 0;
        float max_confidence = 0.0f;
        for (int d = 0; d < output_depth_; ++d) {
          const float confidence =
              GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d);
          if (confidence > max_confidence) {
            class_index = d;
            max_confidence = confidence;
          }
        }
        (*category_mask)[pixel_offset++] = static_cast<char>(class_index);
      }
    }
  } else if (options_->output_type() ==
             ImageSegmenterOptions::CONFIDENCE_MASK) {
    // One float mask per class, each holding per-pixel confidences.
    auto* confidence_masks = segmentation->mutable_confidence_masks();
    for (int d = 0; d < output_depth_; ++d) {
      confidence_masks->add_confidence_mask();
    }
    for (int mask_y = 0; mask_y < segmentation->height(); ++mask_y) {
      for (int mask_x = 0; mask_x < segmentation->width(); ++mask_x) {
        // See the coordinate-reorientation comment in the CATEGORY_MASK
        // branch above.
        OrientCoordinates(/*from_x=*/mask_x,
                          /*from_y=*/mask_y,
                          /*from_orientation=*/mask_orientation,
                          /*to_orientation=*/tensor_orientation,
                          /*from_dimension=*/mask_dimension,
                          /*to_x=*/&tensor_x,
                          /*to_y=*/&tensor_y);
        for (int d = 0; d < output_depth_; ++d) {
          confidence_masks->mutable_confidence_mask(d)->add_value(
              GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d));
        }
      }
    }
  }

  return result;
}
411 
GetOutputConfidence(const TfLiteTensor & output_tensor,int x,int y,int depth)412 float ImageSegmenter::GetOutputConfidence(const TfLiteTensor& output_tensor,
413                                           int x, int y, int depth) {
414   int index = output_width_ * output_depth_ * y + output_depth_ * x + depth;
415   if (has_uint8_outputs_) {
416     const uint8* data = AssertAndReturnTypedTensor<uint8>(&output_tensor);
417     return output_tensor.params.scale *
418            (static_cast<int>(data[index]) - output_tensor.params.zero_point);
419   } else {
420     const float* data = AssertAndReturnTypedTensor<float>(&output_tensor);
421     return data[index];
422   }
423 }
424 
425 }  // namespace vision
426 }  // namespace task
427 }  // namespace tflite
428