1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedder.h"

#include <cstdint>
#include <cstdlib>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/reflection_generated.h"  // from @flatbuffers
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/grafter.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_common.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
34
35 namespace fb = flatbuffers;
36
37 namespace tflite {
38 namespace acceleration {
39
40 namespace {
41 // Class for building the validation entry-point graph that calls into the main
42 // graph and a metrics graph. Like this (boxes are tensors with plural names
43 // meaning possibly multiple tensors, arrows are ops and numbers in parentheses
44 // are subgraph indices):
45 // +--------------------------------------+
46 // | Graph created by this class (1) |
47 // | |
48 // | +-----------input-+ |
49 // | |jpeg input | |
50 // | +-----+-----------+ |
51 // | | |
52 // | | decode |
53 // | v |
54 // | +-----+-----------+ |
55 // | |quantized image | |
56 // | +-----+-----------+ | +-----------------------+
57 // | | | |'main_model' (0) |
58 // | | dequantize (optional) | | +---------------+ |
59 // | v | | |input +---+ |
60 // | +-----+-----------+ | | +---------------+ | |
61 // | |float image | | | ~ |
62 // | +-----+-----------+ | | +---------------+ | |
63 // | | call | | |outputs +<--+ |
64 // | +<------------------------------->+ +---------------+ |
65 // | v | | |
66 // | +-----+-----output+ +---------input+ | +-----------------------+
67 // | |actual outputs | |golden outputs| |
68 // | +-----+-----------+ +-----------+--+ |
69 // | | | |
70 // | | dequantize (optional) | |
71 // | | | |
72 // | +-----+-------------------------+-+ |
73 // | | dequantized actual and golden | |
74 // | | outputs (validation inputs) | |
75 // | +-----+---------------------------+ | +-----------------------+
76 // | | call | |'validation model' (2) |
77 // | +<------------------------------->+ |
78 // | v | | +---------------+ |
79 // | +-----+-----output+ | | |inputs +---+ |
80 // | |results | | | +---------------+ | |
81 // | +-----------------+ | | ~ |
82 // | | | +---------------+ | |
83 // | | | |outputs +<--+ |
84 // | | | +---------------+ |
85 // | | | |
86 // +--------------------------------------+ +-----------------------+
87 //
88 // It's important the 'main_model' has subgraph index 0 so that it is used as
89 // the primary subgraph by the TFLite interpreter. The other indices are
90 // arbitrary.
91 // TODO(b/172541832): Handle a main model with more than one subgraph.
92 //
93 // Note that the jpeg input is marked as an input in this graph, as TFLite
94 // graphs must have inputs. However, it will be pre-filled from the jpeg_data
95 // and doesn't need to be filled by the user of the model.
96 //
97
// Name prefix passed to the tensor copy in ValidationGraphBuilder::Tensors()
// for the validation model outputs; it is prepended to a tensor name unless
// the name already starts with it.
constexpr char kMetricPrefix[] = "metrics/";
99
100 class ValidationGraphBuilder {
101 public:
102 ValidationGraphBuilder(const Model* main_model,
103 std::vector<std::string> jpeg_data, float scale,
104 int64_t zero_point, const Model* validation_model,
105 const reflection::Schema* schema,
106 bool use_ondevice_cpu_for_golden);
107
108 // Builds the part of the model drawn above until the call to the validation
109 // graph. The model is used to generate golden outputs. Calls Finish on the
110 // FlatbufferBuilder.
111 absl::Status BuildIntermediateModel(fb::FlatBufferBuilder* fbb);
112 // Builds the whole model as drawn above. The subgraph_with_golden_outputs
113 // should be the result of invoking subgraph 1 on the output of
114 // BuildIntermediateModel(). Calls Finish on the FlatbufferBuilder.
115 absl::Status BuildFinalModel(fb::FlatBufferBuilder* fbb,
116 Subgraph* subgraph_with_golden_outputs);
117
118 ValidationGraphBuilder(const ValidationGraphBuilder&) = delete;
119 ValidationGraphBuilder& operator=(const ValidationGraphBuilder&) = delete;
120
121 private:
122 static const int32_t kModelVersion = 3;
123 static const int32_t kSkippedIndex = -1;
124 // Operator code numbering.
125 static const int32_t kCallOperatorCode = 0;
126 static const int32_t kDequantizeOperatorCode = 1;
127 static const int32_t kDecodeJpegOperatorCode = 2;
128 // Subgraph numbering.
129 static const int32_t kMainSubgraphIndex = 0;
130 static const int32_t kValidationSubgraphIndex = 2;
131
132 // Allocation of tensors, for communication between methods that create the
133 // tensors, the operations and the buffers.
134 // (Some of these vectors will always contain only one element, but using the
135 // same type for them simplifies the code a lot).
struct TensorInfo {
  TensorInfo() = default;
  // The struct owns jpeg_buffer_contents and releases it with free() in the
  // destructor, so a copy would release the same pointer twice. Copying is
  // therefore disabled; pass the struct by pointer or reference instead.
  TensorInfo(const TensorInfo&) = delete;
  TensorInfo& operator=(const TensorInfo&) = delete;
  // free(nullptr) is a no-op, so this is safe when no buffer was written.
  ~TensorInfo() { free(jpeg_buffer_contents); }

  // Tensor indices that form the inputs / outputs of the entry-point graph.
  std::vector<int32_t> entrypoint_inputs;
  std::vector<int32_t> entrypoint_outputs;
  // Index of the string tensor holding the jpeg data.
  std::vector<int32_t> jpeg_images;

  // With float main model, both quantized_images and float_images are set,
  // and float_images is the same as main input. With a quantized model
  // only quantized_images is set and it's the same as main input.
  std::vector<int32_t> quantized_images;
  std::vector<int32_t> float_images;

  std::vector<int32_t> main_outputs;  // First half of validation_inputs.
  std::vector<int32_t> validation_inputs;
  // With a float model, validation_inputs is used directly. With a quantized
  // model, the inputs are first dequantized.
  // Some models have a mixture of quantized outputs that need to be
  // dequantized to floats; and integer outputs. For integer outputs
  // kSkippedIndex is used.
  std::vector<int32_t> dequantized_validation_inputs;
  std::vector<int32_t> validation_outputs;

  // Serialized jpeg string tensor contents and their length in bytes. The
  // pointer is owned by this struct (see the destructor); -1 / nullptr mean
  // "not filled in yet".
  char* jpeg_buffer_contents = nullptr;
  int32_t jpeg_buffer_length = -1;
  // Height and width requested from the jpeg decode op; -1 until set.
  int32_t jpeg_height = -1;
  int32_t jpeg_width = -1;
};
164
MakeModel(bool intermediate_only,Subgraph * subgraph_with_golden_outputs)165 absl::StatusOr<fb::Offset<Model>> MakeModel(
166 bool intermediate_only, Subgraph* subgraph_with_golden_outputs) {
167 TensorInfo tensor_info;
168 auto operator_codes = OperatorCodes();
169 if (!operator_codes.ok()) {
170 return operator_codes.status();
171 }
172 auto subgraphs = SubGraphs(intermediate_only, &tensor_info);
173 if (!subgraphs.ok()) {
174 return subgraphs.status();
175 }
176 auto buffers =
177 Buffers(intermediate_only, tensor_info, subgraph_with_golden_outputs);
178 if (!buffers.ok()) {
179 return buffers.status();
180 }
181 return CreateModel(fbb_, kModelVersion, *operator_codes, *subgraphs,
182 fbb_.CreateString("validation"), *buffers,
183 /* metadata_buffer */ 0, /* metadata */ 0,
184 /* signature_defs */ 0);
185 }
186
187 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<OperatorCode>>>>
OperatorCodes()188 OperatorCodes() {
189 #define RET_CHECK_INDEX(constant, code_index) \
190 do { \
191 if ((constant) != (code_index)) { \
192 return absl::InternalError(absl::StrFormat( \
193 "Operator code indexing mismatch %s (%d) != %s (%d)", #constant, \
194 (constant), #code_index, (code_index))); \
195 } \
196 } while (0)
197 std::vector<fb::Offset<OperatorCode>> codes;
198 RET_CHECK_INDEX(kCallOperatorCode, codes.size());
199 codes.push_back(CreateOperatorCode(fbb_, BuiltinOperator_CUSTOM,
200 fbb_.CreateString("validation/call")));
201 RET_CHECK_INDEX(kDequantizeOperatorCode, codes.size());
202 codes.push_back(CreateOperatorCode(fbb_, BuiltinOperator_DEQUANTIZE));
203 RET_CHECK_INDEX(kDecodeJpegOperatorCode, codes.size());
204 codes.push_back(
205 CreateOperatorCode(fbb_, BuiltinOperator_CUSTOM,
206 fbb_.CreateString("validation/decode_jpeg")));
207 return fbb_.CreateVector(codes);
208 #undef RET_CHECK_INDEX
209 }
210
Tensors(bool intermediate_only,TensorInfo * tensor_info)211 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Tensor>>>> Tensors(
212 bool intermediate_only, TensorInfo* tensor_info) {
213 std::vector<fb::Offset<Tensor>> tensors;
214 int buffer_count = 0;
215
216 // Copy tensors from a source subgraph, overriding the batch_size where
217 // necessary (the called subgraph always uses batch size 1, the calling
218 // subgraph always uses batch size equal jpeg_data_.size()).
219 auto copy =
220 [&tensors, this, &buffer_count](
221 const SubGraph* from_subgraph, const fb::Vector<int32_t>* indices,
222 std::vector<int32_t>* store_indices_into, int batch_size,
223 const std::string prefix = "",
224 std::function<absl::StatusOr<bool>(const Tensor*, int)> filter =
225 nullptr) -> absl::Status {
226 int counter = 0;
227 for (auto index = indices->cbegin(); index != indices->cend();
228 index++, counter++) {
229 const Tensor* tensor = from_subgraph->tensors()->Get(*index);
230 if (filter) {
231 auto statusor = filter(tensor, counter);
232 if (!statusor.ok()) {
233 return statusor.status();
234 } else if (!statusor.value()) {
235 store_indices_into->push_back(kSkippedIndex);
236 continue;
237 }
238 }
239 std::vector<int32_t> shape{tensor->shape()->cbegin(),
240 tensor->shape()->cend()};
241 if (shape.size() >= 2 && shape[0] == 1 && batch_size > 0) {
242 shape[0] = batch_size;
243 }
244 std::vector<int32_t> shape_signature;
245 if (tensor->shape_signature()) {
246 shape_signature.assign(tensor->shape_signature()->cbegin(),
247 tensor->shape_signature()->cend());
248 if (shape_signature.size() >= 2 && shape_signature[0] == 1 &&
249 batch_size > 0) {
250 shape_signature[0] = batch_size;
251 }
252 }
253 auto quantization_parameters = helper_.CopyTable(
254 "tflite.QuantizationParameters", tensor->quantization());
255 if (!quantization_parameters.ok()) {
256 return quantization_parameters.status();
257 }
258 auto sparsity_parameters =
259 helper_.CopyTable("tflite.SparsityParameters", tensor->sparsity());
260 if (!sparsity_parameters.ok()) {
261 return sparsity_parameters.status();
262 }
263 store_indices_into->push_back(tensors.size());
264 std::string name = tensor->name()->str();
265 if (!prefix.empty() && name.find(prefix) != 0) { // NOLINT
266 name = prefix + name;
267 }
268 tensors.push_back(CreateTensor(
269 fbb_, fbb_.CreateVector(shape), tensor->type(), buffer_count,
270 fbb_.CreateString(name), *quantization_parameters,
271 tensor->is_variable(), *sparsity_parameters,
272 shape_signature.empty() ? 0 : fbb_.CreateVector(shape_signature)));
273 buffer_count++;
274 }
275 return absl::OkStatus();
276 };
277 // Input image, jpeg data.
278 tensor_info->jpeg_images.push_back(tensors.size());
279 DynamicBuffer jpeg_buffer;
280 for (int i = 0; i < jpeg_data_.size(); i++) {
281 jpeg_buffer.AddString(jpeg_data_[i].data(), jpeg_data_[i].size());
282 }
283 tensor_info->jpeg_buffer_length =
284 jpeg_buffer.WriteToBuffer(&(tensor_info->jpeg_buffer_contents));
285 tensors.push_back(CreateTensor(
286 fbb_,
287 fbb_.CreateVector(
288 std::vector<int32_t>{static_cast<int32_t>(jpeg_data_.size())}),
289 TensorType::TensorType_STRING, buffer_count,
290 fbb_.CreateString("call/jpeg_images")));
291 buffer_count++;
292
293 // Input image.
294 const SubGraph* main_subgraph = main_model_->subgraphs()->Get(0);
295 const Tensor* input_tensor =
296 main_subgraph->tensors()->Get(main_subgraph->inputs()->Get(0));
297 tensor_info->jpeg_height = input_tensor->shape()->Get(1);
298 tensor_info->jpeg_width = input_tensor->shape()->Get(2);
299 if (input_tensor->type() == TensorType_FLOAT32) {
300 // Quantized.
301 std::vector<int32_t> input_shape{input_tensor->shape()->cbegin(),
302 input_tensor->shape()->cend()};
303 input_shape[0] = static_cast<int32_t>(jpeg_data_.size());
304 tensor_info->quantized_images.push_back(tensors.size());
305 tensors.push_back(CreateTensor(
306 fbb_, fbb_.CreateVector(input_shape), TensorType::TensorType_UINT8,
307 buffer_count, fbb_.CreateString("call/quant_image"),
308 CreateQuantizationParameters(
309 fbb_, 0, 0, fbb_.CreateVector(std::vector<float>{scale_}),
310 fbb_.CreateVector(std::vector<int64_t>{zero_point_}))));
311 buffer_count++;
312 // Float.
313 tensor_info->float_images.push_back(tensors.size());
314 tensors.push_back(CreateTensor(
315 fbb_, fbb_.CreateVector(input_shape), TensorType::TensorType_FLOAT32,
316 buffer_count, fbb_.CreateString("call/float_image")));
317 buffer_count++;
318 } else {
319 // Quantized only.
320 auto status = copy(main_model_->subgraphs()->Get(0),
321 main_model_->subgraphs()->Get(0)->inputs(),
322 &tensor_info->quantized_images, jpeg_data_.size());
323 if (!status.ok()) {
324 return status;
325 }
326 }
327
328 // Validation inputs, actual.
329 auto status = copy(main_model_->subgraphs()->Get(0),
330 main_model_->subgraphs()->Get(0)->outputs(),
331 &tensor_info->main_outputs, jpeg_data_.size());
332 if (!status.ok()) {
333 return status;
334 }
335 if (intermediate_only) {
336 return fbb_.CreateVector(tensors);
337 }
338 // Validation inputs, golden.
339 tensor_info->validation_inputs = tensor_info->main_outputs;
340 status = copy(main_model_->subgraphs()->Get(0),
341 main_model_->subgraphs()->Get(0)->outputs(),
342 &tensor_info->validation_inputs, jpeg_data_.size());
343 if (!status.ok()) {
344 return status;
345 }
346 // Entrypoint inputs. Golden first (validator relies on this).
347 for (int i = tensor_info->validation_inputs.size() / 2;
348 i < tensor_info->validation_inputs.size(); i++) {
349 tensor_info->entrypoint_inputs.push_back(
350 tensor_info->validation_inputs[i]);
351 }
352 tensor_info->entrypoint_inputs.push_back(tensor_info->jpeg_images[0]);
353 // Validation inputs, dequantized.
354 status = copy(
355 validation_model_->subgraphs()->Get(0),
356 validation_model_->subgraphs()->Get(0)->inputs(),
357 &tensor_info->dequantized_validation_inputs, jpeg_data_.size(), "",
358 [&tensors, &tensor_info, this](const Tensor* validation_model_input,
359 int i) -> absl::StatusOr<bool> {
360 // validation_model_input is the tensor for metrics calculation.
361 // validation_graph_input is the under-construction graph will be
362 // given to the metrics calculation but need to be dequantized first.
363 const Tensor* validation_graph_input = fb::GetTemporaryPointer(
364 fbb_, tensors[tensor_info->validation_inputs[i]]);
365 if (validation_model_input->type() == TensorType_FLOAT32 &&
366 (validation_graph_input->type() == TensorType_UINT8 ||
367 validation_graph_input->type() == TensorType_INT8)) {
368 return true;
369 } else if (validation_model_input->type() !=
370 validation_graph_input->type()) {
371 const char* name = "(null)";
372 if (validation_model_input->name()) {
373 name = validation_model_input->name()->c_str();
374 }
375 return absl::InvalidArgumentError(
376 absl::StrFormat("Validation model input %s with type %d is "
377 "incompatible with main model output type %d",
378 name, validation_model_input->type(),
379 validation_graph_input->type()));
380 } else {
381 return false;
382 }
383 });
384 if (!status.ok()) {
385 return status;
386 }
387 // Validation outputs.
388 status = copy(validation_model_->subgraphs()->Get(0),
389 validation_model_->subgraphs()->Get(0)->outputs(),
390 &tensor_info->validation_outputs, jpeg_data_.size(),
391 kMetricPrefix);
392 if (!status.ok()) {
393 return status;
394 }
395 // Outputs from entrypoint graph
396 // Actuals first (validator relies on this);
397 for (int i = 0; i < tensor_info->validation_inputs.size() / 2; i++) {
398 tensor_info->entrypoint_outputs.push_back(
399 tensor_info->validation_inputs[i]);
400 }
401 // Metrics.
402 for (int i = 0; i < tensor_info->validation_outputs.size(); i++) {
403 tensor_info->entrypoint_outputs.push_back(
404 tensor_info->validation_outputs[i]);
405 }
406 return fbb_.CreateVector(tensors);
407 }
408
409 // Create the options for the custom call op (see call.cc for the options
410 // format).
CallOpCustomOptions(int subgraph)411 fb::Offset<fb::Vector<uint8_t>> CallOpCustomOptions(int subgraph) {
412 flexbuffers::Builder fbb;
413 fbb.Map([&] {
414 fbb.Int("subgraph_index", subgraph);
415 fbb.Int("loop_count", static_cast<int32_t>(jpeg_data_.size()));
416 });
417 fbb.Finish();
418 return fbb_.CreateVector(fbb.GetBuffer());
419 }
420
421 // Create the options for the custom jpeg op (see decode_jpeg.cc for the
422 // options format).
JpegOpCustomOptions(int height,int width,int channels)423 fb::Offset<fb::Vector<uint8_t>> JpegOpCustomOptions(int height, int width,
424 int channels) {
425 flexbuffers::Builder fbb;
426 fbb.Map([&] {
427 fbb.Int("height", height);
428 fbb.Int("width", width);
429 fbb.Int("channels", channels);
430 fbb.Int("num_images", jpeg_data_.size());
431 });
432 fbb.Finish();
433 return fbb_.CreateVector(fbb.GetBuffer());
434 }
435
Operators(bool intermediate_only,const TensorInfo & tensor_info)436 fb::Offset<fb::Vector<fb::Offset<Operator>>> Operators(
437 bool intermediate_only, const TensorInfo& tensor_info) {
438 std::vector<fb::Offset<Operator>> ops;
439 // Jpeg decode.
440 ops.push_back(
441 CreateOperator(fbb_, kDecodeJpegOperatorCode,
442 fbb_.CreateVector(tensor_info.jpeg_images),
443 fbb_.CreateVector(tensor_info.quantized_images),
444 tflite::BuiltinOptions_NONE, 0,
445 JpegOpCustomOptions(tensor_info.jpeg_height,
446 tensor_info.jpeg_width, 3)));
447 if (!tensor_info.float_images.empty()) {
448 // Dequantize.
449 ops.push_back(
450 CreateOperator(fbb_, kDequantizeOperatorCode,
451 fbb_.CreateVector(tensor_info.quantized_images),
452 fbb_.CreateVector(tensor_info.float_images),
453 BuiltinOptions_DequantizeOptions, 0));
454 // Call main model.
455 ops.push_back(CreateOperator(fbb_, kCallOperatorCode,
456 fbb_.CreateVector(tensor_info.float_images),
457 fbb_.CreateVector(tensor_info.main_outputs),
458 tflite::BuiltinOptions_NONE, 0,
459 CallOpCustomOptions(kMainSubgraphIndex),
460 tflite::CustomOptionsFormat_FLEXBUFFERS));
461 } else {
462 // Call main model.
463 ops.push_back(
464 CreateOperator(fbb_, kCallOperatorCode,
465 fbb_.CreateVector(tensor_info.quantized_images),
466 fbb_.CreateVector(tensor_info.main_outputs),
467 tflite::BuiltinOptions_NONE, 0,
468 CallOpCustomOptions(kMainSubgraphIndex),
469 tflite::CustomOptionsFormat_FLEXBUFFERS));
470 }
471 if (intermediate_only) {
472 return fbb_.CreateVector(ops);
473 }
474 // Call validation model.
475 std::vector<int32_t> validation_input_indices;
476 for (int i = 0; i < tensor_info.dequantized_validation_inputs.size(); i++) {
477 int32_t validation_input_index;
478 if (tensor_info.dequantized_validation_inputs[i] == kSkippedIndex) {
479 validation_input_index = tensor_info.validation_inputs[i];
480 } else {
481 validation_input_index = tensor_info.dequantized_validation_inputs[i];
482 std::vector<int32_t> dequantize_inputs{
483 tensor_info.validation_inputs[i]};
484 std::vector<int32_t> dequantize_outputs{
485 tensor_info.dequantized_validation_inputs[i]};
486 ops.push_back(CreateOperator(fbb_, kDequantizeOperatorCode,
487 fbb_.CreateVector(dequantize_inputs),
488 fbb_.CreateVector(dequantize_outputs),
489 BuiltinOptions_DequantizeOptions, 0));
490 }
491 validation_input_indices.push_back(validation_input_index);
492 }
493 ops.push_back(CreateOperator(
494 fbb_, kCallOperatorCode, fbb_.CreateVector(validation_input_indices),
495 fbb_.CreateVector(tensor_info.validation_outputs),
496 tflite::BuiltinOptions_NONE, 0,
497 CallOpCustomOptions(kValidationSubgraphIndex),
498 tflite::CustomOptionsFormat_FLEXBUFFERS));
499 return fbb_.CreateVector(ops);
500 }
501
SubGraphs(bool intermediate_only,TensorInfo * tensor_info)502 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<SubGraph>>>> SubGraphs(
503 bool intermediate_only, TensorInfo* tensor_info) {
504 auto tensors = Tensors(intermediate_only, tensor_info);
505 if (!tensors.ok()) {
506 return tensors.status();
507 }
508 std::vector<fb::Offset<SubGraph>> graphs;
509 if (intermediate_only) {
510 graphs.push_back(CreateSubGraph(
511 fbb_, *tensors, fbb_.CreateVector(tensor_info->jpeg_images),
512 fbb_.CreateVector(tensor_info->main_outputs),
513 Operators(intermediate_only, *tensor_info),
514 fbb_.CreateString("call")));
515 } else {
516 graphs.push_back(CreateSubGraph(
517 fbb_, *tensors, fbb_.CreateVector(tensor_info->entrypoint_inputs),
518 fbb_.CreateVector(tensor_info->entrypoint_outputs),
519 Operators(intermediate_only, *tensor_info),
520 fbb_.CreateString("call")));
521 }
522 return fbb_.CreateVector(graphs);
523 }
524
Buffers(bool intermediate_only,const TensorInfo & tensor_info,Subgraph * subgraph_with_golden_outputs)525 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Buffer>>>> Buffers(
526 bool intermediate_only, const TensorInfo& tensor_info,
527 Subgraph* subgraph_with_golden_outputs) {
528 std::vector<fb::Offset<Buffer>> buffers;
529
530 // The buffers created in this method map 1:1 to the tensors created in
531 // Tensors() - a tensor at index X uses buffer at index X. The numbering
532 // is checked along the way using the RET_CHECK_INDEX macro below.
533 #define RET_CHECK_INDEX(tensor_index, buffer_index) \
534 do { \
535 if ((tensor_index) != (buffer_index)) { \
536 return absl::InternalError(absl::StrFormat( \
537 "%s:%d, Tensor/buffer indexing mismatch %s (%d) != %s (%d)", \
538 __FILE__, __LINE__, #tensor_index, (tensor_index), #buffer_index, \
539 (buffer_index))); \
540 } \
541 } while (0)
542
543 // Jpeg input.
544 RET_CHECK_INDEX(tensor_info.jpeg_images.size(), 1);
545 RET_CHECK_INDEX(tensor_info.jpeg_images[0], buffers.size());
546 std::vector<uint8_t> jpeg_buffer_vec{
547 reinterpret_cast<const uint8_t*>(tensor_info.jpeg_buffer_contents),
548 reinterpret_cast<const uint8_t*>(tensor_info.jpeg_buffer_contents) +
549 tensor_info.jpeg_buffer_length};
550 buffers.push_back(CreateBuffer(fbb_, fbb_.CreateVector(jpeg_buffer_vec)));
551
552 // Decoded and dequantized image.
553 RET_CHECK_INDEX(tensor_info.quantized_images.size(), 1);
554 RET_CHECK_INDEX(tensor_info.quantized_images[0], buffers.size());
555 buffers.push_back(CreateBuffer(fbb_));
556 if (!tensor_info.float_images.empty()) {
557 RET_CHECK_INDEX(tensor_info.float_images.size(), 1);
558 RET_CHECK_INDEX(tensor_info.float_images[0], buffers.size());
559 buffers.push_back(CreateBuffer(fbb_));
560 }
561
562 // Main graph outputs / first half of validation inputs.
563 auto main_subgraph = main_model_->subgraphs()->Get(0);
564 RET_CHECK_INDEX(main_subgraph->outputs()->size(),
565 tensor_info.main_outputs.size());
566 int main_output_index = 0;
567 int validation_graph_input_index = 0;
568 for (auto i = main_subgraph->outputs()->cbegin();
569 i != main_subgraph->outputs()->cend(); i++) {
570 RET_CHECK_INDEX(tensor_info.main_outputs[main_output_index],
571 buffers.size());
572 main_output_index++;
573 if (!intermediate_only) {
574 RET_CHECK_INDEX(
575 tensor_info.validation_inputs[validation_graph_input_index],
576 buffers.size());
577 validation_graph_input_index++;
578 }
579 auto t = main_subgraph->tensors()->Get(*i);
580 auto status = helper_.CopyTableToVector(
581 "tflite.Buffer", main_model_->buffers()->Get(t->buffer()), &buffers);
582 if (!status.ok()) {
583 return status;
584 }
585 }
586 if (intermediate_only) {
587 return fbb_.CreateVector(buffers);
588 }
589
590 // Golden outputs / second half of validation inputs.
591 RET_CHECK_INDEX(tensor_info.validation_inputs.size(),
592 validation_graph_input_index +
593 subgraph_with_golden_outputs->outputs().size());
594 for (auto i : subgraph_with_golden_outputs->outputs()) {
595 RET_CHECK_INDEX(
596 tensor_info.validation_inputs[validation_graph_input_index],
597 buffers.size());
598 validation_graph_input_index++;
599 auto t = subgraph_with_golden_outputs->tensor(i);
600 if (!use_ondevice_cpu_for_golden_) {
601 std::vector<uint8_t> output_data{
602 reinterpret_cast<const uint8_t*>(t->data.raw),
603 reinterpret_cast<const uint8_t*>(t->data.raw + t->bytes)};
604 buffers.push_back(CreateBuffer(fbb_, fbb_.CreateVector(output_data)));
605 } else {
606 buffers.push_back(CreateBuffer(fbb_));
607 }
608 }
609
610 auto validation_model_subgraph = validation_model_->subgraphs()->Get(0);
611 // Dequantized validation inputs.
612 RET_CHECK_INDEX(tensor_info.dequantized_validation_inputs.size(),
613 validation_model_subgraph->inputs()->size());
614 int validation_graph_dequantized_input_index = 0;
615 for (auto i = validation_model_subgraph->inputs()->cbegin();
616 i != validation_model_subgraph->inputs()->cend(); i++) {
617 if (tensor_info.dequantized_validation_inputs
618 [validation_graph_dequantized_input_index] == kSkippedIndex) {
619 validation_graph_dequantized_input_index++;
620 continue;
621 }
622 RET_CHECK_INDEX(tensor_info.dequantized_validation_inputs
623 [validation_graph_dequantized_input_index],
624 buffers.size());
625 validation_graph_dequantized_input_index++;
626 auto t = validation_model_subgraph->tensors()->Get(*i);
627 auto status = helper_.CopyTableToVector(
628 "tflite.Buffer", validation_model_->buffers()->Get(t->buffer()),
629 &buffers);
630 if (!status.ok()) {
631 return status;
632 }
633 }
634
635 // Validation outputs.
636 RET_CHECK_INDEX(tensor_info.validation_outputs.size(),
637 validation_model_subgraph->outputs()->size());
638 int validation_graph_output_index = 0;
639 for (auto i = validation_model_subgraph->outputs()->cbegin();
640 i != validation_model_subgraph->outputs()->cend(); i++) {
641 RET_CHECK_INDEX(
642 tensor_info.validation_outputs[validation_graph_output_index],
643 buffers.size());
644 validation_graph_output_index++;
645 auto t = validation_model_subgraph->tensors()->Get(*i);
646 auto status = helper_.CopyTableToVector(
647 "tflite.Buffer", validation_model_->buffers()->Get(t->buffer()),
648 &buffers);
649 if (!status.ok()) {
650 return status;
651 }
652 }
653 return fbb_.CreateVector(buffers);
654 #undef RET_CHECK_INDEX
655 }
656
657 const Model* main_model_;
658 std::vector<std::string> jpeg_data_;
659 float scale_;
660 int64_t zero_point_;
661 const Model* validation_model_;
662 const reflection::Schema* schema_;
663 fb::FlatBufferBuilder fbb_;
664 FlatbufferHelper helper_;
665 bool use_ondevice_cpu_for_golden_;
666 };
667
// Out-of-class definitions of the static constant members declared in
// ValidationGraphBuilder. Without a definition they can still be used as
// compile-time constants, but they cannot be passed by reference (e.g., used
// with absl::StrFormat).
const int32_t ValidationGraphBuilder::kModelVersion;
const int32_t ValidationGraphBuilder::kSkippedIndex;
const int32_t ValidationGraphBuilder::kCallOperatorCode;
const int32_t ValidationGraphBuilder::kDequantizeOperatorCode;
const int32_t ValidationGraphBuilder::kDecodeJpegOperatorCode;
const int32_t ValidationGraphBuilder::kMainSubgraphIndex;
const int32_t ValidationGraphBuilder::kValidationSubgraphIndex;
678
ValidationGraphBuilder(const Model * main_model,std::vector<std::string> jpeg_data,float scale,int64_t zero_point,const Model * validation_model,const reflection::Schema * schema,bool use_ondevice_cpu_for_golden)679 ValidationGraphBuilder::ValidationGraphBuilder(
680 const Model* main_model, std::vector<std::string> jpeg_data, float scale,
681 int64_t zero_point, const Model* validation_model,
682 const reflection::Schema* schema, bool use_ondevice_cpu_for_golden)
683 : main_model_(main_model),
684 jpeg_data_(jpeg_data),
685 scale_(scale),
686 zero_point_(zero_point),
687 validation_model_(validation_model),
688 schema_(schema),
689 helper_(&fbb_, schema_),
690 use_ondevice_cpu_for_golden_(use_ondevice_cpu_for_golden) {}
691
BuildIntermediateModel(fb::FlatBufferBuilder * fbb)692 absl::Status ValidationGraphBuilder::BuildIntermediateModel(
693 fb::FlatBufferBuilder* fbb) {
694 fbb_.Reset();
695 auto model = MakeModel(/* intermediate_only */ true,
696 /* subgraph_with_golden_outputs */ nullptr);
697 if (!model.ok()) {
698 return model.status();
699 }
700 fbb_.Finish(*model, "TFL3");
701 std::vector<const Model*> models{main_model_,
702 fb::GetRoot<Model>(fbb_.GetBufferPointer())};
703 std::vector<std::string> subgraph_names_not_important(2);
704 return CombineModels(fbb, models, subgraph_names_not_important, schema_);
705 }
706
BuildFinalModel(fb::FlatBufferBuilder * fbb,Subgraph * subgraph_with_golden_outputs)707 absl::Status ValidationGraphBuilder::BuildFinalModel(
708 fb::FlatBufferBuilder* fbb, Subgraph* subgraph_with_golden_outputs) {
709 fbb_.Reset();
710 auto model =
711 MakeModel(/* intermediate_only */ false, subgraph_with_golden_outputs);
712 if (!model.ok()) {
713 return model.status();
714 }
715 fbb_.Finish(*model, "TFL3");
716 std::vector<const Model*> models{main_model_,
717 fb::GetRoot<Model>(fbb_.GetBufferPointer()),
718 validation_model_};
719 std::vector<std::string> subgraph_names;
720 auto main_subgraph_name = main_model_->subgraphs()->Get(0)->name();
721 subgraph_names.push_back(main_subgraph_name ? main_subgraph_name->str() : "");
722 subgraph_names.push_back("VALIDATION:main");
723 subgraph_names.push_back("VALIDATION:metrics");
724 return CombineModels(fbb, models, subgraph_names, schema_);
725 }
726
DescribeShape(const fb::Vector<int32_t> * shape)727 std::string DescribeShape(const fb::Vector<int32_t>* shape) {
728 std::string desc = "[";
729 for (int i = 0; i < shape->size(); i++) {
730 if (i != 0) {
731 desc += ", ";
732 }
733 desc += absl::StrFormat("%d", shape->Get(i));
734 }
735 desc += "]";
736 return desc;
737 }
738
739 } // namespace
740
// Stores the arguments in the corresponding members. No checking is done
// here; see ValidateInputs() for the checks on the models and jpeg data.
Embedder::Embedder(const Model* main_model,
                   const std::vector<std::string>& jpeg_data, float scale,
                   int64_t zero_point, const Model* validation_model,
                   const reflection::Schema* schema,
                   bool use_ondevice_cpu_for_golden)
    : main_model_(main_model),
      jpeg_data_(jpeg_data),
      scale_(scale),
      zero_point_(zero_point),
      validation_model_(validation_model),
      schema_(schema),
      use_ondevice_cpu_for_golden_(use_ondevice_cpu_for_golden) {}
753
ValidateInputs()754 absl::Status Embedder::ValidateInputs() {
755 #define VALIDATE(condition, ...) \
756 if (!(condition)) { \
757 return absl::InvalidArgumentError(absl::StrFormat(__VA_ARGS__)); \
758 }
759 VALIDATE(main_model_, "main_model may not be null");
760 VALIDATE(main_model_->subgraphs()->size(), "main model must have subgraphs");
761 const SubGraph* main_subgraph = main_model_->subgraphs()->Get(0);
762 VALIDATE(main_subgraph->inputs()->size() == 1,
763 "main subgraph must have 1 input (got %d)",
764 main_subgraph->inputs()->size());
765 const auto* shape =
766 main_subgraph->tensors()->Get(main_subgraph->inputs()->Get(0))->shape();
767 VALIDATE(shape->size() == 4,
768 "main subgraph input must have 4 dimensions (got %d)",
769 shape->size());
770 VALIDATE(shape->Get(0) == 1,
771 "main subgraph input must have batch size (got %d)", shape->Get(0));
772 VALIDATE(shape->Get(3) == 1 || shape->Get(3) == 3,
773 "main subgraph input must have 1 or 3 channels (got %d)",
774 shape->Get(3));
775
776 VALIDATE(!jpeg_data_.empty(), "must have at least 1 jpeg input");
777 int jpeg_number = 0;
778 for (const std::string& jpeg_image_data : jpeg_data_) {
779 int width, height, components;
780 decode_jpeg_kernel::JpegHeader header{0};
781 auto status = decode_jpeg_kernel::ReadJpegHeader(
782 {jpeg_image_data.data(), static_cast<int>(jpeg_image_data.size())},
783 &header);
784 VALIDATE(status.code == kTfLiteOk,
785 "Failed to decompress jpeg data at index %d: %s", jpeg_number,
786 status.error_message.c_str());
787 width = header.width;
788 height = header.height;
789 components = header.channels;
790 VALIDATE(height == shape->Get(1) && width == shape->Get(2) &&
791 components == shape->Get(3),
792 "Jpeg input at index %d has different size from input tensor "
793 "(jpeg h: %d, w: %d, c: %d; tensor h: %d, w: %d, c: %d)",
794 jpeg_number, height, width, components, shape->Get(1),
795 shape->Get(2), shape->Get(3));
796 jpeg_number++;
797 }
798
799 int main_output_count = main_subgraph->outputs()->size();
800 VALIDATE(main_output_count > 0,
801 "main subgraph must have at least 1 output (got %d)",
802 main_output_count);
803 VALIDATE(validation_model_->subgraphs()->size(),
804 "validation model must have subgraphs");
805 const SubGraph* validation_subgraph = validation_model_->subgraphs()->Get(0);
806 int validation_input_count = validation_subgraph->inputs()->size();
807 VALIDATE(
808 validation_input_count == main_output_count * 2,
809 "validation subgraph input count must be 2 times main subgraph output "
810 "count (validation output count: %d, main subgraph output count: %d)",
811 validation_input_count, main_output_count);
812 for (int i = 0; i < main_output_count; i++) {
813 auto main_output_tensor =
814 main_subgraph->tensors()->Get(main_subgraph->outputs()->Get(i));
815 auto main_output_shape = DescribeShape(main_output_tensor->shape());
816 VALIDATE(main_output_shape != "[]",
817 "All main outputs must be tensors, %d is a scalar", i);
818 VALIDATE(main_output_tensor->name()->str().find(kMetricPrefix) != 0,
819 "Main output %d name %s clashes with metrics/ prefix", i,
820 main_output_tensor->name()->c_str());
821 auto validation_input_shape_1 =
822 DescribeShape(validation_subgraph->tensors()
823 ->Get(validation_subgraph->inputs()->Get(i))
824 ->shape());
825 auto validation_input_shape_2 = DescribeShape(
826 validation_subgraph->tensors()
827 ->Get(validation_subgraph->inputs()->Get(main_output_count + i))
828 ->shape());
829 VALIDATE(main_output_shape == validation_input_shape_1,
830 "Main output %d dimensions %s do not match validation input %d "
831 "dimensions %s",
832 i, main_output_shape, i, validation_input_shape_1);
833 VALIDATE(main_output_shape == validation_input_shape_2,
834 "Main output %d dimensions %s do not match validation input %d "
835 "dimensions %s",
836 i, main_output_shape, main_output_count + i,
837 validation_input_shape_2);
838 }
839 int validation_output_count = validation_subgraph->outputs()->size();
840 VALIDATE(validation_output_count >= 2,
841 "validation output count must be at least 2 (got "
842 "%d)",
843 validation_output_count);
844 bool seen_ok = false;
845 const std::string kOk = "ok";
846 const std::string kPrefixedOk = kMetricPrefix + kOk;
847 std::string names = "";
848 for (int i = 0; i < validation_output_count; i++) {
849 const Tensor* t = validation_subgraph->tensors()->Get(
850 validation_subgraph->outputs()->Get(i));
851 VALIDATE(t->shape()->size(),
852 "validation outputs must be tensors, %d is a scalar", i);
853 seen_ok = (seen_ok || (kOk == t->name()->str()) ||
854 (kPrefixedOk == t->name()->str()));
855 if (i != 0) {
856 names += ", ";
857 }
858 names += t->name()->str();
859 }
860 VALIDATE(seen_ok, "validation must have an output named 'ok' (saw %s)",
861 names);
862 #undef VALIDATE
863 return absl::OkStatus();
864 }
865
CreateModelWithEmbeddedValidation(fb::FlatBufferBuilder * fbb,ops::builtin::BuiltinOpResolver * resolver)866 absl::Status Embedder::CreateModelWithEmbeddedValidation(
867 fb::FlatBufferBuilder* fbb, ops::builtin::BuiltinOpResolver* resolver) {
868 auto status = ValidateInputs();
869 if (!status.ok()) {
870 return status;
871 }
872 fb::FlatBufferBuilder intermediate_fbb;
873 ValidationGraphBuilder builder(main_model_, jpeg_data_, scale_, zero_point_,
874 validation_model_, schema_,
875 use_ondevice_cpu_for_golden_);
876 status = builder.BuildIntermediateModel(&intermediate_fbb);
877 if (!status.ok()) {
878 return status;
879 }
880 auto intermediate_model = FlatBufferModel::VerifyAndBuildFromBuffer(
881 reinterpret_cast<const char*>(intermediate_fbb.GetBufferPointer()),
882 intermediate_fbb.GetSize());
883 if (!intermediate_model) {
884 return absl::InternalError("Failed to load intermediate model");
885 }
886 std::unique_ptr<Interpreter> interpreter;
887 InterpreterBuilder(*intermediate_model, *resolver)(&interpreter);
888 if (!interpreter) {
889 return absl::InternalError(
890 "Failed to build interpreter from intermediate model");
891 }
892 Subgraph* subgraph = interpreter->subgraph(1);
893 if (subgraph->AllocateTensors() != kTfLiteOk) {
894 return absl::InternalError(
895 "Failed to AllocateTensors() on validation subgraph of intermediate "
896 "model");
897 }
898 if (subgraph->Invoke() != kTfLiteOk) {
899 return absl::InternalError(
900 "Failed to Invoke() on validation subgraph of intermediate model");
901 }
902 return builder.BuildFinalModel(fbb, subgraph);
903 }
904
905 } // namespace acceleration
906 } // namespace tflite
907