/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"

#include "absl/algorithm/container.h"  // for absl::c_linear_search
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"      // for absl::StrCat
#include "absl/types/optional.h"       // for absl::optional
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"

namespace tflite {
namespace task {
namespace vision {
namespace {

using ::absl::StatusCode;
using ::tflite::ColorSpaceType_RGB;
using ::tflite::ContentProperties;
using ::tflite::ContentProperties_ImageProperties;
using ::tflite::EnumNameContentProperties;
using ::tflite::ImageProperties;
using ::tflite::TensorMetadata;
using ::tflite::metadata::ModelMetadataExtractor;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::core::TfLiteEngine;

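// Returns the metadata of the (single) input tensor if the model provides it,
// nullptr if the model has no (or only very partial) metadata, or an error if
// the model does not have exactly one input TensorMetadata.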
StatusOr<const TensorMetadata*> GetInputTensorMetadataIfAny(
    const ModelMetadataExtractor& metadata_extractor) {
  if (metadata_extractor.GetModelMetadata() == nullptr ||
      metadata_extractor.GetModelMetadata()->subgraph_metadata() == nullptr) {
    // Some models have no metadata at all (or very partial), so exit early.
    return nullptr;
  } else if (metadata_extractor.GetInputTensorCount() != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Models are assumed to have a single input TensorMetadata.",
        TfLiteSupportStatus::kInvalidNumInputTensorsError);
  }

  const TensorMetadata* metadata = metadata_extractor.GetInputTensorMetadata(0);

  if (metadata == nullptr) {
    // Should never happen.
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Input TensorMetadata is null.");
  }

  return metadata;
}

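// Returns the ImageProperties attached to the given input tensor metadata,
// nullptr if the metadata carries no content properties, or an error if the
// content properties are not of the ImageProperties type.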
StatusOr<const ImageProperties*> GetImagePropertiesIfAny(
    const TensorMetadata& tensor_metadata) {
  if (tensor_metadata.content() == nullptr ||
      tensor_metadata.content()->content_properties() == nullptr) {
    return nullptr;
  }

  ContentProperties type = tensor_metadata.content()->content_properties_type();

  if (type != ContentProperties_ImageProperties) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat(
            "Expected ImageProperties for tensor ",
            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
            ", got ", EnumNameContentProperties(type), "."),
        TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
  }

  return tensor_metadata.content()->content_properties_as_ImageProperties();
}

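// Extracts the NormalizationOptions (mean and std values) from the given
// tensor metadata, if any. Returns absl::nullopt when no normalization
// process unit is present, and an error if the mean/std values are malformed
// (sizes differ, or are neither 1 nor 3).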
StatusOr<absl::optional<NormalizationOptions>> GetNormalizationOptionsIfAny(
    const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(
      const tflite::ProcessUnit* normalization_process_unit,
      ModelMetadataExtractor::FindFirstProcessUnit(
          tensor_metadata, tflite::ProcessUnitOptions_NormalizationOptions));
  if (normalization_process_unit == nullptr) {
    return {absl::nullopt};
  }
  const tflite::NormalizationOptions* tf_normalization_options =
      normalization_process_unit->options_as_NormalizationOptions();
  const auto mean_values = tf_normalization_options->mean();
  const auto std_values = tf_normalization_options->std();
  if (mean_values->size() != std_values->size()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("NormalizationOptions: expected mean and std of same "
                     "dimension, got ",
                     mean_values->size(), " and ", std_values->size(), "."),
        TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
  }
  absl::optional<NormalizationOptions> normalization_options;
  if (mean_values->size() == 1) {
    normalization_options = NormalizationOptions{
        .mean_values = {mean_values->Get(0), mean_values->Get(0),
                        mean_values->Get(0)},
        .std_values = {std_values->Get(0), std_values->Get(0),
                       std_values->Get(0)},
        .num_values = 1};
  } else if (mean_values->size() == 3) {
    normalization_options = NormalizationOptions{
        .mean_values = {mean_values->Get(0), mean_values->Get(1),
                        mean_values->Get(2)},
        .std_values = {std_values->Get(0), std_values->Get(1),
                       std_values->Get(2)},
        .num_values = 3};
  } else {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("NormalizationOptions: only 1 or 3 mean and std "
                     "values are supported, got ",
                     mean_values->size(), "."),
        TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
  }
  return normalization_options;
}

}  // namespace

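// Builds the ImageTensorSpecs (width, height, color space, tensor type and
// optional normalization options) for the single image input tensor of the
// given interpreter, cross-checking the tensor against the model metadata
// when available. Usage sketch (assuming TfLiteEngine exposes interpreter()
// and metadata_extractor() accessors on an already-initialized engine):
//
//   ASSIGN_OR_RETURN(
//       ImageTensorSpecs specs,
//       BuildInputImageTensorSpecs(*engine.interpreter(),
//                                  *engine.metadata_extractor()));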
StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
    const TfLiteEngine::Interpreter& interpreter,
    const tflite::metadata::ModelMetadataExtractor& metadata_extractor) {
  ASSIGN_OR_RETURN(const TensorMetadata* metadata,
                   GetInputTensorMetadataIfAny(metadata_extractor));

  const ImageProperties* props = nullptr;
  absl::optional<NormalizationOptions> normalization_options;
  if (metadata != nullptr) {
    ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
    ASSIGN_OR_RETURN(normalization_options,
                     GetNormalizationOptionsIfAny(*metadata));
  }

  if (TfLiteEngine::InputCount(&interpreter) != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Models are assumed to have a single input.",
        TfLiteSupportStatus::kInvalidNumInputTensorsError);
  }

  // Input-related specifications.
  const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0);
  if (input_tensor->dims->size != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Only 4D tensors in BHWD layout are supported.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32};
  TfLiteType input_type = input_tensor->type;
  if (!absl::c_linear_search(valid_types, input_type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat(
            "Type mismatch for input tensor ", input_tensor->name,
            ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ",
            TfLiteTypeGetName(input_type), "."),
        TfLiteSupportStatus::kInvalidInputTensorTypeError);
  }

  // The expected layout is BHWD, i.e. batch x height x width x color
  // See https://www.tensorflow.org/guide/tensors
  const int batch = input_tensor->dims->data[0];
  const int height = input_tensor->dims->data[1];
  const int width = input_tensor->dims->data[2];
  const int depth = input_tensor->dims->data[3];

  if (props != nullptr && props->color_space() != ColorSpaceType_RGB) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "Only RGB color space is supported for now.",
                                   TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (batch != 1 || depth != 3) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("The input tensor should have dimensions 1 x height x "
                     "width x 3. Got ",
                     batch, " x ", height, " x ", width, " x ", depth, "."),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  int bytes_size = input_tensor->bytes;
  size_t byte_depth =
      input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);

  // Sanity checks.
  if (input_type == kTfLiteFloat32) {
    if (!normalization_options.has_value()) {
      return CreateStatusWithPayload(
          absl::StatusCode::kNotFound,
          "Input tensor has type kTfLiteFloat32: it requires specifying "
          "NormalizationOptions metadata to preprocess input images.",
          TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
    } else if (bytes_size / sizeof(float) %
                   normalization_options.value().num_values !=
               0) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          "The number of elements in the input tensor must be a multiple of "
          "the number of normalization parameters.",
          TfLiteSupportStatus::kInvalidArgumentError);
    }
  }
  if (width <= 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument, "The input width should be positive.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  if (height <= 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument, "The input height should be positive.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  if (bytes_size != height * width * depth * byte_depth) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "The input size in bytes does not correspond to the expected number of "
        "pixels.",
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }

  // Note: in the future, additional checks against `props->default_size()`
  // might be added. Also, verify that NormalizationOptions, if any, do specify
  // a single value when color space is grayscale.

  ImageTensorSpecs result;
  result.image_width = width;
  result.image_height = height;
  result.color_space = ColorSpaceType_RGB;
  result.tensor_type = input_type;
  result.normalization_options = normalization_options;

  return result;
}

}  // namespace vision
}  // namespace task
}  // namespace tflite