1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
16
17 #include "absl/status/status.h"
18 #include "tensorflow_lite_support/cc/common.h"
19 #include "tensorflow_lite_support/cc/port/integral_types.h"
20 #include "tensorflow_lite_support/cc/port/status_macros.h"
21 #include "tensorflow_lite_support/cc/port/statusor.h"
22 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
23
24 namespace tflite {
25 namespace task {
26 namespace vision {
27 namespace {
28
29 using ::absl::StatusCode;
30 using ::tflite::ColorSpaceType_RGB;
31 using ::tflite::ContentProperties;
32 using ::tflite::ContentProperties_ImageProperties;
33 using ::tflite::EnumNameContentProperties;
34 using ::tflite::ImageProperties;
35 using ::tflite::TensorMetadata;
36 using ::tflite::metadata::ModelMetadataExtractor;
37 using ::tflite::support::CreateStatusWithPayload;
38 using ::tflite::support::StatusOr;
39 using ::tflite::support::TfLiteSupportStatus;
40 using ::tflite::task::core::TfLiteEngine;
41
GetInputTensorMetadataIfAny(const ModelMetadataExtractor & metadata_extractor)42 StatusOr<const TensorMetadata*> GetInputTensorMetadataIfAny(
43 const ModelMetadataExtractor& metadata_extractor) {
44 if (metadata_extractor.GetModelMetadata() == nullptr ||
45 metadata_extractor.GetModelMetadata()->subgraph_metadata() == nullptr) {
46 // Some models have no metadata at all (or very partial), so exit early.
47 return nullptr;
48 } else if (metadata_extractor.GetInputTensorCount() != 1) {
49 return CreateStatusWithPayload(
50 StatusCode::kInvalidArgument,
51 "Models are assumed to have a single input TensorMetadata.",
52 TfLiteSupportStatus::kInvalidNumInputTensorsError);
53 }
54
55 const TensorMetadata* metadata = metadata_extractor.GetInputTensorMetadata(0);
56
57 if (metadata == nullptr) {
58 // Should never happen.
59 return CreateStatusWithPayload(StatusCode::kInternal,
60 "Input TensorMetadata is null.");
61 }
62
63 return metadata;
64 }
65
GetImagePropertiesIfAny(const TensorMetadata & tensor_metadata)66 StatusOr<const ImageProperties*> GetImagePropertiesIfAny(
67 const TensorMetadata& tensor_metadata) {
68 if (tensor_metadata.content() == nullptr ||
69 tensor_metadata.content()->content_properties() == nullptr) {
70 return nullptr;
71 }
72
73 ContentProperties type = tensor_metadata.content()->content_properties_type();
74
75 if (type != ContentProperties_ImageProperties) {
76 return CreateStatusWithPayload(
77 StatusCode::kInvalidArgument,
78 absl::StrCat(
79 "Expected ImageProperties for tensor ",
80 tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
81 ", got ", EnumNameContentProperties(type), "."),
82 TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
83 }
84
85 return tensor_metadata.content()->content_properties_as_ImageProperties();
86 }
87
GetNormalizationOptionsIfAny(const TensorMetadata & tensor_metadata)88 StatusOr<absl::optional<NormalizationOptions>> GetNormalizationOptionsIfAny(
89 const TensorMetadata& tensor_metadata) {
90 ASSIGN_OR_RETURN(
91 const tflite::ProcessUnit* normalization_process_unit,
92 ModelMetadataExtractor::FindFirstProcessUnit(
93 tensor_metadata, tflite::ProcessUnitOptions_NormalizationOptions));
94 if (normalization_process_unit == nullptr) {
95 return {absl::nullopt};
96 }
97 const tflite::NormalizationOptions* tf_normalization_options =
98 normalization_process_unit->options_as_NormalizationOptions();
99 const auto mean_values = tf_normalization_options->mean();
100 const auto std_values = tf_normalization_options->std();
101 if (mean_values->size() != std_values->size()) {
102 return CreateStatusWithPayload(
103 StatusCode::kInvalidArgument,
104 absl::StrCat("NormalizationOptions: expected mean and std of same "
105 "dimension, got ",
106 mean_values->size(), " and ", std_values->size(), "."),
107 TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
108 }
109 absl::optional<NormalizationOptions> normalization_options;
110 if (mean_values->size() == 1) {
111 normalization_options = NormalizationOptions{
112 .mean_values = {mean_values->Get(0), mean_values->Get(0),
113 mean_values->Get(0)},
114 .std_values = {std_values->Get(0), std_values->Get(0),
115 std_values->Get(0)},
116 .num_values = 1};
117 } else if (mean_values->size() == 3) {
118 normalization_options = NormalizationOptions{
119 .mean_values = {mean_values->Get(0), mean_values->Get(1),
120 mean_values->Get(2)},
121 .std_values = {std_values->Get(0), std_values->Get(1),
122 std_values->Get(2)},
123 .num_values = 3};
124 } else {
125 return CreateStatusWithPayload(
126 StatusCode::kInvalidArgument,
127 absl::StrCat("NormalizationOptions: only 1 or 3 mean and std "
128 "values are supported, got ",
129 mean_values->size(), "."),
130 TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
131 }
132 return normalization_options;
133 }
134
135 } // namespace
136
BuildInputImageTensorSpecs(const TfLiteEngine::Interpreter & interpreter,const tflite::metadata::ModelMetadataExtractor & metadata_extractor)137 StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
138 const TfLiteEngine::Interpreter& interpreter,
139 const tflite::metadata::ModelMetadataExtractor& metadata_extractor) {
140 ASSIGN_OR_RETURN(const TensorMetadata* metadata,
141 GetInputTensorMetadataIfAny(metadata_extractor));
142
143 const ImageProperties* props = nullptr;
144 absl::optional<NormalizationOptions> normalization_options;
145 if (metadata != nullptr) {
146 ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
147 ASSIGN_OR_RETURN(normalization_options,
148 GetNormalizationOptionsIfAny(*metadata));
149 }
150
151 if (TfLiteEngine::InputCount(&interpreter) != 1) {
152 return CreateStatusWithPayload(
153 StatusCode::kInvalidArgument,
154 "Models are assumed to have a single input.",
155 TfLiteSupportStatus::kInvalidNumInputTensorsError);
156 }
157
158 // Input-related specifications.
159 const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0);
160 if (input_tensor->dims->size != 4) {
161 return CreateStatusWithPayload(
162 StatusCode::kInvalidArgument,
163 "Only 4D tensors in BHWD layout are supported.",
164 TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
165 }
166 static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32};
167 TfLiteType input_type = input_tensor->type;
168 if (!absl::c_linear_search(valid_types, input_type)) {
169 return CreateStatusWithPayload(
170 StatusCode::kInvalidArgument,
171 absl::StrCat(
172 "Type mismatch for input tensor ", input_tensor->name,
173 ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ",
174 TfLiteTypeGetName(input_type), "."),
175 TfLiteSupportStatus::kInvalidInputTensorTypeError);
176 }
177
178 // The expected layout is BHWD, i.e. batch x height x width x color
179 // See https://www.tensorflow.org/guide/tensors
180 const int batch = input_tensor->dims->data[0];
181 const int height = input_tensor->dims->data[1];
182 const int width = input_tensor->dims->data[2];
183 const int depth = input_tensor->dims->data[3];
184
185 if (props != nullptr && props->color_space() != ColorSpaceType_RGB) {
186 return CreateStatusWithPayload(StatusCode::kInvalidArgument,
187 "Only RGB color space is supported for now.",
188 TfLiteSupportStatus::kInvalidArgumentError);
189 }
190 if (batch != 1 || depth != 3) {
191 return CreateStatusWithPayload(
192 StatusCode::kInvalidArgument,
193 absl::StrCat("The input tensor should have dimensions 1 x height x "
194 "width x 3. Got ",
195 batch, " x ", height, " x ", width, " x ", depth, "."),
196 TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
197 }
198 int bytes_size = input_tensor->bytes;
199 size_t byte_depth =
200 input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);
201
202 // Sanity checks.
203 if (input_type == kTfLiteFloat32) {
204 if (!normalization_options.has_value()) {
205 return CreateStatusWithPayload(
206 absl::StatusCode::kNotFound,
207 "Input tensor has type kTfLiteFloat32: it requires specifying "
208 "NormalizationOptions metadata to preprocess input images.",
209 TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
210 } else if (bytes_size / sizeof(float) %
211 normalization_options.value().num_values !=
212 0) {
213 return CreateStatusWithPayload(
214 StatusCode::kInvalidArgument,
215 "The number of elements in the input tensor must be a multiple of "
216 "the number of normalization parameters.",
217 TfLiteSupportStatus::kInvalidArgumentError);
218 }
219 }
220 if (width <= 0) {
221 return CreateStatusWithPayload(
222 StatusCode::kInvalidArgument, "The input width should be positive.",
223 TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
224 }
225 if (height <= 0) {
226 return CreateStatusWithPayload(
227 StatusCode::kInvalidArgument, "The input height should be positive.",
228 TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
229 }
230 if (bytes_size != height * width * depth * byte_depth) {
231 return CreateStatusWithPayload(
232 StatusCode::kInvalidArgument,
233 "The input size in bytes does not correspond to the expected number of "
234 "pixels.",
235 TfLiteSupportStatus::kInvalidInputTensorSizeError);
236 }
237
238 // Note: in the future, additional checks against `props->default_size()`
239 // might be added. Also, verify that NormalizationOptions, if any, do specify
240 // a single value when color space is grayscale.
241
242 ImageTensorSpecs result;
243 result.image_width = width;
244 result.image_height = height;
245 result.color_space = ColorSpaceType_RGB;
246 result.tensor_type = input_type;
247 result.normalization_options = normalization_options;
248
249 return result;
250 }
251
252 } // namespace vision
253 } // namespace task
254 } // namespace tflite
255