• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedder.h"

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/reflection_generated.h"  // from @flatbuffers
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/grafter.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_common.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
34 
35 namespace fb = flatbuffers;
36 
37 namespace tflite {
38 namespace acceleration {
39 
40 namespace {
41 // Class for building the validation entry-point graph that calls into the main
42 // graph and a metrics graph. Like this (boxes are tensors with plural names
43 // meaning possibly multiple tensors, arrows are ops and numbers in parentheses
44 // are subgraph indices):
45 // +--------------------------------------+
46 // | Graph created by this class (1)      |
47 // |                                      |
48 // | +-----------input-+                  |
49 // | |jpeg input       |                  |
50 // | +-----+-----------+                  |
51 // |       |                              |
52 // |       | decode                       |
53 // |       v                              |
54 // | +-----+-----------+                  |
55 // | |quantized image  |                  |
56 // | +-----+-----------+                  |  +-----------------------+
57 // |       |                              |  |'main_model' (0)       |
58 // |       | dequantize (optional)        |  | +---------------+     |
59 // |       v                              |  | |input          +---+ |
60 // | +-----+-----------+                  |  | +---------------+   | |
61 // | |float image      |                  |  |                     ~ |
62 // | +-----+-----------+                  |  | +---------------+   | |
63 // |       |  call                        |  | |outputs        +<--+ |
64 // |       +<------------------------------->+ +---------------+     |
65 // |       v                              |  |                       |
66 // | +-----+-----output+ +---------input+ |  +-----------------------+
67 // | |actual outputs   | |golden outputs| |
68 // | +-----+-----------+ +-----------+--+ |
69 // |       |                         |    |
70 // |       | dequantize (optional)   |    |
71 // |       |                         |    |
72 // | +-----+-------------------------+-+  |
73 // | | dequantized actual and golden   |  |
74 // | | outputs (validation inputs)     |  |
75 // | +-----+---------------------------+  |  +-----------------------+
76 // |       |  call                        |  |'validation model' (2) |
77 // |       +<------------------------------->+                       |
78 // |       v                              |  | +---------------+     |
79 // | +-----+-----output+                  |  | |inputs         +---+ |
80 // | |results          |                  |  | +---------------+   | |
81 // | +-----------------+                  |  |                     ~ |
82 // |                                      |  | +---------------+   | |
83 // |                                      |  | |outputs        +<--+ |
84 // |                                      |  | +---------------+     |
85 // |                                      |  |                       |
86 // +--------------------------------------+  +-----------------------+
87 //
88 // It's important the 'main_model' has subgraph index 0 so that it is used as
89 // the primary subgraph by the TFLite interpreter. The other indices are
90 // arbitrary.
91 // TODO(b/172541832): Handle a main model with more than one subgraph.
92 //
93 // Note that the jpeg input is marked as an input in this graph, as TFLite
94 // graphs must have inputs. However, it will be pre-filled from the jpeg_data
95 // and doesn't need to be filled by the user of the model.
96 //
97 
98 constexpr char kMetricPrefix[] = "metrics/";
99 
100 class ValidationGraphBuilder {
101  public:
102   ValidationGraphBuilder(const Model* main_model,
103                          std::vector<std::string> jpeg_data, float scale,
104                          int64_t zero_point, const Model* validation_model,
105                          const reflection::Schema* schema,
106                          bool use_ondevice_cpu_for_golden);
107 
108   // Builds the part of the model drawn above until the call to the validation
109   // graph. The model is used to generate golden outputs. Calls Finish on the
110   // FlatbufferBuilder.
111   absl::Status BuildIntermediateModel(fb::FlatBufferBuilder* fbb);
112   // Builds the whole model as drawn above. The subgraph_with_golden_outputs
113   // should be the result of invoking subgraph 1 on the output of
114   // BuildIntermediateModel(). Calls Finish on the FlatbufferBuilder.
115   absl::Status BuildFinalModel(fb::FlatBufferBuilder* fbb,
116                                Subgraph* subgraph_with_golden_outputs);
117 
118   ValidationGraphBuilder(const ValidationGraphBuilder&) = delete;
119   ValidationGraphBuilder& operator=(const ValidationGraphBuilder&) = delete;
120 
121  private:
122   static const int32_t kModelVersion = 3;
123   static const int32_t kSkippedIndex = -1;
124   // Operator code numbering.
125   static const int32_t kCallOperatorCode = 0;
126   static const int32_t kDequantizeOperatorCode = 1;
127   static const int32_t kDecodeJpegOperatorCode = 2;
128   // Subgraph numbering.
129   static const int32_t kMainSubgraphIndex = 0;
130   static const int32_t kValidationSubgraphIndex = 2;
131 
132   // Allocation of tensors, for communication between methods that create the
133   // tensors, the operations and the buffers.
134   // (Some of these vectors will always contain only one element, but using the
135   // same type for them simplifies the code a lot).
136   struct TensorInfo {
137     std::vector<int32_t> entrypoint_inputs;
138     std::vector<int32_t> entrypoint_outputs;
139     std::vector<int32_t> jpeg_images;
140 
141     // With float main model, both quantized_images and float_images are set,
142     // and float_images is the same as main input. With a quantized model
143     // only quantized_images is set and it's the same as main input.
144     std::vector<int32_t> quantized_images;
145     std::vector<int32_t> float_images;
146 
147     std::vector<int32_t> main_outputs;  // First half of validation_inputs.
148     std::vector<int32_t> validation_inputs;
149     // With a float model, validation_inputs is used directly. With a quantized
150     // model, the inputs are first dequantized.
151     // Some models have a mixture of quantized outputs that need to be
152     // dequantized to floats; and integer outputs. For integer outputs
153     // kSkippedIndex is used.
154     std::vector<int32_t> dequantized_validation_inputs;
155     std::vector<int32_t> validation_outputs;
156 
157     char* jpeg_buffer_contents = nullptr;
158     int32_t jpeg_buffer_length = -1;
159     int32_t jpeg_height = -1;
160     int32_t jpeg_width = -1;
161 
~TensorInfotflite::acceleration::__anond4c55ce60111::ValidationGraphBuilder::TensorInfo162     ~TensorInfo() { free(jpeg_buffer_contents); }
163   };
164 
MakeModel(bool intermediate_only,Subgraph * subgraph_with_golden_outputs)165   absl::StatusOr<fb::Offset<Model>> MakeModel(
166       bool intermediate_only, Subgraph* subgraph_with_golden_outputs) {
167     TensorInfo tensor_info;
168     auto operator_codes = OperatorCodes();
169     if (!operator_codes.ok()) {
170       return operator_codes.status();
171     }
172     auto subgraphs = SubGraphs(intermediate_only, &tensor_info);
173     if (!subgraphs.ok()) {
174       return subgraphs.status();
175     }
176     auto buffers =
177         Buffers(intermediate_only, tensor_info, subgraph_with_golden_outputs);
178     if (!buffers.ok()) {
179       return buffers.status();
180     }
181     return CreateModel(fbb_, kModelVersion, *operator_codes, *subgraphs,
182                        fbb_.CreateString("validation"), *buffers,
183                        /* metadata_buffer */ 0, /* metadata */ 0,
184                        /* signature_defs */ 0);
185   }
186 
187   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<OperatorCode>>>>
OperatorCodes()188   OperatorCodes() {
189 #define RET_CHECK_INDEX(constant, code_index)                              \
190   do {                                                                     \
191     if ((constant) != (code_index)) {                                      \
192       return absl::InternalError(absl::StrFormat(                          \
193           "Operator code indexing mismatch %s (%d) != %s (%d)", #constant, \
194           (constant), #code_index, (code_index)));                         \
195     }                                                                      \
196   } while (0)
197     std::vector<fb::Offset<OperatorCode>> codes;
198     RET_CHECK_INDEX(kCallOperatorCode, codes.size());
199     codes.push_back(CreateOperatorCode(fbb_, BuiltinOperator_CUSTOM,
200                                        fbb_.CreateString("validation/call")));
201     RET_CHECK_INDEX(kDequantizeOperatorCode, codes.size());
202     codes.push_back(CreateOperatorCode(fbb_, BuiltinOperator_DEQUANTIZE));
203     RET_CHECK_INDEX(kDecodeJpegOperatorCode, codes.size());
204     codes.push_back(
205         CreateOperatorCode(fbb_, BuiltinOperator_CUSTOM,
206                            fbb_.CreateString("validation/decode_jpeg")));
207     return fbb_.CreateVector(codes);
208 #undef RET_CHECK_INDEX
209   }
210 
Tensors(bool intermediate_only,TensorInfo * tensor_info)211   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Tensor>>>> Tensors(
212       bool intermediate_only, TensorInfo* tensor_info) {
213     std::vector<fb::Offset<Tensor>> tensors;
214     int buffer_count = 0;
215 
216     // Copy tensors from a source subgraph, overriding the batch_size where
217     // necessary (the called subgraph always uses batch size 1, the calling
218     // subgraph always uses batch size equal jpeg_data_.size()).
219     auto copy =
220         [&tensors, this, &buffer_count](
221             const SubGraph* from_subgraph, const fb::Vector<int32_t>* indices,
222             std::vector<int32_t>* store_indices_into, int batch_size,
223             const std::string prefix = "",
224             std::function<absl::StatusOr<bool>(const Tensor*, int)> filter =
225                 nullptr) -> absl::Status {
226       int counter = 0;
227       for (auto index = indices->cbegin(); index != indices->cend();
228            index++, counter++) {
229         const Tensor* tensor = from_subgraph->tensors()->Get(*index);
230         if (filter) {
231           auto statusor = filter(tensor, counter);
232           if (!statusor.ok()) {
233             return statusor.status();
234           } else if (!statusor.value()) {
235             store_indices_into->push_back(kSkippedIndex);
236             continue;
237           }
238         }
239         std::vector<int32_t> shape{tensor->shape()->cbegin(),
240                                    tensor->shape()->cend()};
241         if (shape.size() >= 2 && shape[0] == 1 && batch_size > 0) {
242           shape[0] = batch_size;
243         }
244         std::vector<int32_t> shape_signature;
245         if (tensor->shape_signature()) {
246           shape_signature.assign(tensor->shape_signature()->cbegin(),
247                                  tensor->shape_signature()->cend());
248           if (shape_signature.size() >= 2 && shape_signature[0] == 1 &&
249               batch_size > 0) {
250             shape_signature[0] = batch_size;
251           }
252         }
253         auto quantization_parameters = helper_.CopyTable(
254             "tflite.QuantizationParameters", tensor->quantization());
255         if (!quantization_parameters.ok()) {
256           return quantization_parameters.status();
257         }
258         auto sparsity_parameters =
259             helper_.CopyTable("tflite.SparsityParameters", tensor->sparsity());
260         if (!sparsity_parameters.ok()) {
261           return sparsity_parameters.status();
262         }
263         store_indices_into->push_back(tensors.size());
264         std::string name = tensor->name()->str();
265         if (!prefix.empty() && name.find(prefix) != 0) {  // NOLINT
266           name = prefix + name;
267         }
268         tensors.push_back(CreateTensor(
269             fbb_, fbb_.CreateVector(shape), tensor->type(), buffer_count,
270             fbb_.CreateString(name), *quantization_parameters,
271             tensor->is_variable(), *sparsity_parameters,
272             shape_signature.empty() ? 0 : fbb_.CreateVector(shape_signature)));
273         buffer_count++;
274       }
275       return absl::OkStatus();
276     };
277     // Input image, jpeg data.
278     tensor_info->jpeg_images.push_back(tensors.size());
279     DynamicBuffer jpeg_buffer;
280     for (int i = 0; i < jpeg_data_.size(); i++) {
281       jpeg_buffer.AddString(jpeg_data_[i].data(), jpeg_data_[i].size());
282     }
283     tensor_info->jpeg_buffer_length =
284         jpeg_buffer.WriteToBuffer(&(tensor_info->jpeg_buffer_contents));
285     tensors.push_back(CreateTensor(
286         fbb_,
287         fbb_.CreateVector(
288             std::vector<int32_t>{static_cast<int32_t>(jpeg_data_.size())}),
289         TensorType::TensorType_STRING, buffer_count,
290         fbb_.CreateString("call/jpeg_images")));
291     buffer_count++;
292 
293     // Input image.
294     const SubGraph* main_subgraph = main_model_->subgraphs()->Get(0);
295     const Tensor* input_tensor =
296         main_subgraph->tensors()->Get(main_subgraph->inputs()->Get(0));
297     tensor_info->jpeg_height = input_tensor->shape()->Get(1);
298     tensor_info->jpeg_width = input_tensor->shape()->Get(2);
299     if (input_tensor->type() == TensorType_FLOAT32) {
300       // Quantized.
301       std::vector<int32_t> input_shape{input_tensor->shape()->cbegin(),
302                                        input_tensor->shape()->cend()};
303       input_shape[0] = static_cast<int32_t>(jpeg_data_.size());
304       tensor_info->quantized_images.push_back(tensors.size());
305       tensors.push_back(CreateTensor(
306           fbb_, fbb_.CreateVector(input_shape), TensorType::TensorType_UINT8,
307           buffer_count, fbb_.CreateString("call/quant_image"),
308           CreateQuantizationParameters(
309               fbb_, 0, 0, fbb_.CreateVector(std::vector<float>{scale_}),
310               fbb_.CreateVector(std::vector<int64_t>{zero_point_}))));
311       buffer_count++;
312       // Float.
313       tensor_info->float_images.push_back(tensors.size());
314       tensors.push_back(CreateTensor(
315           fbb_, fbb_.CreateVector(input_shape), TensorType::TensorType_FLOAT32,
316           buffer_count, fbb_.CreateString("call/float_image")));
317       buffer_count++;
318     } else {
319       // Quantized only.
320       auto status = copy(main_model_->subgraphs()->Get(0),
321                          main_model_->subgraphs()->Get(0)->inputs(),
322                          &tensor_info->quantized_images, jpeg_data_.size());
323       if (!status.ok()) {
324         return status;
325       }
326     }
327 
328     // Validation inputs, actual.
329     auto status = copy(main_model_->subgraphs()->Get(0),
330                        main_model_->subgraphs()->Get(0)->outputs(),
331                        &tensor_info->main_outputs, jpeg_data_.size());
332     if (!status.ok()) {
333       return status;
334     }
335     if (intermediate_only) {
336       return fbb_.CreateVector(tensors);
337     }
338     // Validation inputs, golden.
339     tensor_info->validation_inputs = tensor_info->main_outputs;
340     status = copy(main_model_->subgraphs()->Get(0),
341                   main_model_->subgraphs()->Get(0)->outputs(),
342                   &tensor_info->validation_inputs, jpeg_data_.size());
343     if (!status.ok()) {
344       return status;
345     }
346     // Entrypoint inputs. Golden first (validator relies on this).
347     for (int i = tensor_info->validation_inputs.size() / 2;
348          i < tensor_info->validation_inputs.size(); i++) {
349       tensor_info->entrypoint_inputs.push_back(
350           tensor_info->validation_inputs[i]);
351     }
352     tensor_info->entrypoint_inputs.push_back(tensor_info->jpeg_images[0]);
353     // Validation inputs, dequantized.
354     status = copy(
355         validation_model_->subgraphs()->Get(0),
356         validation_model_->subgraphs()->Get(0)->inputs(),
357         &tensor_info->dequantized_validation_inputs, jpeg_data_.size(), "",
358         [&tensors, &tensor_info, this](const Tensor* validation_model_input,
359                                        int i) -> absl::StatusOr<bool> {
360           // validation_model_input is the tensor for metrics calculation.
361           // validation_graph_input is the under-construction graph will be
362           // given to the metrics calculation but need to be dequantized first.
363           const Tensor* validation_graph_input = fb::GetTemporaryPointer(
364               fbb_, tensors[tensor_info->validation_inputs[i]]);
365           if (validation_model_input->type() == TensorType_FLOAT32 &&
366               (validation_graph_input->type() == TensorType_UINT8 ||
367                validation_graph_input->type() == TensorType_INT8)) {
368             return true;
369           } else if (validation_model_input->type() !=
370                      validation_graph_input->type()) {
371             const char* name = "(null)";
372             if (validation_model_input->name()) {
373               name = validation_model_input->name()->c_str();
374             }
375             return absl::InvalidArgumentError(
376                 absl::StrFormat("Validation model input %s with type %d is "
377                                 "incompatible with main model output type %d",
378                                 name, validation_model_input->type(),
379                                 validation_graph_input->type()));
380           } else {
381             return false;
382           }
383         });
384     if (!status.ok()) {
385       return status;
386     }
387     // Validation outputs.
388     status = copy(validation_model_->subgraphs()->Get(0),
389                   validation_model_->subgraphs()->Get(0)->outputs(),
390                   &tensor_info->validation_outputs, jpeg_data_.size(),
391                   kMetricPrefix);
392     if (!status.ok()) {
393       return status;
394     }
395     // Outputs from entrypoint graph
396     // Actuals first (validator relies on this);
397     for (int i = 0; i < tensor_info->validation_inputs.size() / 2; i++) {
398       tensor_info->entrypoint_outputs.push_back(
399           tensor_info->validation_inputs[i]);
400     }
401     // Metrics.
402     for (int i = 0; i < tensor_info->validation_outputs.size(); i++) {
403       tensor_info->entrypoint_outputs.push_back(
404           tensor_info->validation_outputs[i]);
405     }
406     return fbb_.CreateVector(tensors);
407   }
408 
409   // Create the options for the custom call op (see call.cc for the options
410   // format).
CallOpCustomOptions(int subgraph)411   fb::Offset<fb::Vector<uint8_t>> CallOpCustomOptions(int subgraph) {
412     flexbuffers::Builder fbb;
413     fbb.Map([&] {
414       fbb.Int("subgraph_index", subgraph);
415       fbb.Int("loop_count", static_cast<int32_t>(jpeg_data_.size()));
416     });
417     fbb.Finish();
418     return fbb_.CreateVector(fbb.GetBuffer());
419   }
420 
421   // Create the options for the custom jpeg op (see decode_jpeg.cc for the
422   // options format).
JpegOpCustomOptions(int height,int width,int channels)423   fb::Offset<fb::Vector<uint8_t>> JpegOpCustomOptions(int height, int width,
424                                                       int channels) {
425     flexbuffers::Builder fbb;
426     fbb.Map([&] {
427       fbb.Int("height", height);
428       fbb.Int("width", width);
429       fbb.Int("channels", channels);
430       fbb.Int("num_images", jpeg_data_.size());
431     });
432     fbb.Finish();
433     return fbb_.CreateVector(fbb.GetBuffer());
434   }
435 
Operators(bool intermediate_only,const TensorInfo & tensor_info)436   fb::Offset<fb::Vector<fb::Offset<Operator>>> Operators(
437       bool intermediate_only, const TensorInfo& tensor_info) {
438     std::vector<fb::Offset<Operator>> ops;
439     // Jpeg decode.
440     ops.push_back(
441         CreateOperator(fbb_, kDecodeJpegOperatorCode,
442                        fbb_.CreateVector(tensor_info.jpeg_images),
443                        fbb_.CreateVector(tensor_info.quantized_images),
444                        tflite::BuiltinOptions_NONE, 0,
445                        JpegOpCustomOptions(tensor_info.jpeg_height,
446                                            tensor_info.jpeg_width, 3)));
447     if (!tensor_info.float_images.empty()) {
448       // Dequantize.
449       ops.push_back(
450           CreateOperator(fbb_, kDequantizeOperatorCode,
451                          fbb_.CreateVector(tensor_info.quantized_images),
452                          fbb_.CreateVector(tensor_info.float_images),
453                          BuiltinOptions_DequantizeOptions, 0));
454       // Call main model.
455       ops.push_back(CreateOperator(fbb_, kCallOperatorCode,
456                                    fbb_.CreateVector(tensor_info.float_images),
457                                    fbb_.CreateVector(tensor_info.main_outputs),
458                                    tflite::BuiltinOptions_NONE, 0,
459                                    CallOpCustomOptions(kMainSubgraphIndex),
460                                    tflite::CustomOptionsFormat_FLEXBUFFERS));
461     } else {
462       // Call main model.
463       ops.push_back(
464           CreateOperator(fbb_, kCallOperatorCode,
465                          fbb_.CreateVector(tensor_info.quantized_images),
466                          fbb_.CreateVector(tensor_info.main_outputs),
467                          tflite::BuiltinOptions_NONE, 0,
468                          CallOpCustomOptions(kMainSubgraphIndex),
469                          tflite::CustomOptionsFormat_FLEXBUFFERS));
470     }
471     if (intermediate_only) {
472       return fbb_.CreateVector(ops);
473     }
474     // Call validation model.
475     std::vector<int32_t> validation_input_indices;
476     for (int i = 0; i < tensor_info.dequantized_validation_inputs.size(); i++) {
477       int32_t validation_input_index;
478       if (tensor_info.dequantized_validation_inputs[i] == kSkippedIndex) {
479         validation_input_index = tensor_info.validation_inputs[i];
480       } else {
481         validation_input_index = tensor_info.dequantized_validation_inputs[i];
482         std::vector<int32_t> dequantize_inputs{
483             tensor_info.validation_inputs[i]};
484         std::vector<int32_t> dequantize_outputs{
485             tensor_info.dequantized_validation_inputs[i]};
486         ops.push_back(CreateOperator(fbb_, kDequantizeOperatorCode,
487                                      fbb_.CreateVector(dequantize_inputs),
488                                      fbb_.CreateVector(dequantize_outputs),
489                                      BuiltinOptions_DequantizeOptions, 0));
490       }
491       validation_input_indices.push_back(validation_input_index);
492     }
493     ops.push_back(CreateOperator(
494         fbb_, kCallOperatorCode, fbb_.CreateVector(validation_input_indices),
495         fbb_.CreateVector(tensor_info.validation_outputs),
496         tflite::BuiltinOptions_NONE, 0,
497         CallOpCustomOptions(kValidationSubgraphIndex),
498         tflite::CustomOptionsFormat_FLEXBUFFERS));
499     return fbb_.CreateVector(ops);
500   }
501 
SubGraphs(bool intermediate_only,TensorInfo * tensor_info)502   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<SubGraph>>>> SubGraphs(
503       bool intermediate_only, TensorInfo* tensor_info) {
504     auto tensors = Tensors(intermediate_only, tensor_info);
505     if (!tensors.ok()) {
506       return tensors.status();
507     }
508     std::vector<fb::Offset<SubGraph>> graphs;
509     if (intermediate_only) {
510       graphs.push_back(CreateSubGraph(
511           fbb_, *tensors, fbb_.CreateVector(tensor_info->jpeg_images),
512           fbb_.CreateVector(tensor_info->main_outputs),
513           Operators(intermediate_only, *tensor_info),
514           fbb_.CreateString("call")));
515     } else {
516       graphs.push_back(CreateSubGraph(
517           fbb_, *tensors, fbb_.CreateVector(tensor_info->entrypoint_inputs),
518           fbb_.CreateVector(tensor_info->entrypoint_outputs),
519           Operators(intermediate_only, *tensor_info),
520           fbb_.CreateString("call")));
521     }
522     return fbb_.CreateVector(graphs);
523   }
524 
Buffers(bool intermediate_only,const TensorInfo & tensor_info,Subgraph * subgraph_with_golden_outputs)525   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Buffer>>>> Buffers(
526       bool intermediate_only, const TensorInfo& tensor_info,
527       Subgraph* subgraph_with_golden_outputs) {
528     std::vector<fb::Offset<Buffer>> buffers;
529 
530     // The buffers created in this method map 1:1 to the tensors created in
531     // Tensors() - a tensor at index X uses buffer at index X. The numbering
532     // is checked along the way using the RET_CHECK_INDEX macro below.
533 #define RET_CHECK_INDEX(tensor_index, buffer_index)                         \
534   do {                                                                      \
535     if ((tensor_index) != (buffer_index)) {                                 \
536       return absl::InternalError(absl::StrFormat(                           \
537           "%s:%d, Tensor/buffer indexing mismatch %s (%d) != %s (%d)",      \
538           __FILE__, __LINE__, #tensor_index, (tensor_index), #buffer_index, \
539           (buffer_index)));                                                 \
540     }                                                                       \
541   } while (0)
542 
543     // Jpeg input.
544     RET_CHECK_INDEX(tensor_info.jpeg_images.size(), 1);
545     RET_CHECK_INDEX(tensor_info.jpeg_images[0], buffers.size());
546     std::vector<uint8_t> jpeg_buffer_vec{
547         reinterpret_cast<const uint8_t*>(tensor_info.jpeg_buffer_contents),
548         reinterpret_cast<const uint8_t*>(tensor_info.jpeg_buffer_contents) +
549             tensor_info.jpeg_buffer_length};
550     buffers.push_back(CreateBuffer(fbb_, fbb_.CreateVector(jpeg_buffer_vec)));
551 
552     // Decoded and dequantized image.
553     RET_CHECK_INDEX(tensor_info.quantized_images.size(), 1);
554     RET_CHECK_INDEX(tensor_info.quantized_images[0], buffers.size());
555     buffers.push_back(CreateBuffer(fbb_));
556     if (!tensor_info.float_images.empty()) {
557       RET_CHECK_INDEX(tensor_info.float_images.size(), 1);
558       RET_CHECK_INDEX(tensor_info.float_images[0], buffers.size());
559       buffers.push_back(CreateBuffer(fbb_));
560     }
561 
562     // Main graph outputs / first half of validation inputs.
563     auto main_subgraph = main_model_->subgraphs()->Get(0);
564     RET_CHECK_INDEX(main_subgraph->outputs()->size(),
565                     tensor_info.main_outputs.size());
566     int main_output_index = 0;
567     int validation_graph_input_index = 0;
568     for (auto i = main_subgraph->outputs()->cbegin();
569          i != main_subgraph->outputs()->cend(); i++) {
570       RET_CHECK_INDEX(tensor_info.main_outputs[main_output_index],
571                       buffers.size());
572       main_output_index++;
573       if (!intermediate_only) {
574         RET_CHECK_INDEX(
575             tensor_info.validation_inputs[validation_graph_input_index],
576             buffers.size());
577         validation_graph_input_index++;
578       }
579       auto t = main_subgraph->tensors()->Get(*i);
580       auto status = helper_.CopyTableToVector(
581           "tflite.Buffer", main_model_->buffers()->Get(t->buffer()), &buffers);
582       if (!status.ok()) {
583         return status;
584       }
585     }
586     if (intermediate_only) {
587       return fbb_.CreateVector(buffers);
588     }
589 
590     // Golden outputs / second half of validation inputs.
591     RET_CHECK_INDEX(tensor_info.validation_inputs.size(),
592                     validation_graph_input_index +
593                         subgraph_with_golden_outputs->outputs().size());
594     for (auto i : subgraph_with_golden_outputs->outputs()) {
595       RET_CHECK_INDEX(
596           tensor_info.validation_inputs[validation_graph_input_index],
597           buffers.size());
598       validation_graph_input_index++;
599       auto t = subgraph_with_golden_outputs->tensor(i);
600       if (!use_ondevice_cpu_for_golden_) {
601         std::vector<uint8_t> output_data{
602             reinterpret_cast<const uint8_t*>(t->data.raw),
603             reinterpret_cast<const uint8_t*>(t->data.raw + t->bytes)};
604         buffers.push_back(CreateBuffer(fbb_, fbb_.CreateVector(output_data)));
605       } else {
606         buffers.push_back(CreateBuffer(fbb_));
607       }
608     }
609 
610     auto validation_model_subgraph = validation_model_->subgraphs()->Get(0);
611     // Dequantized validation inputs.
612     RET_CHECK_INDEX(tensor_info.dequantized_validation_inputs.size(),
613                     validation_model_subgraph->inputs()->size());
614     int validation_graph_dequantized_input_index = 0;
615     for (auto i = validation_model_subgraph->inputs()->cbegin();
616          i != validation_model_subgraph->inputs()->cend(); i++) {
617       if (tensor_info.dequantized_validation_inputs
618               [validation_graph_dequantized_input_index] == kSkippedIndex) {
619         validation_graph_dequantized_input_index++;
620         continue;
621       }
622       RET_CHECK_INDEX(tensor_info.dequantized_validation_inputs
623                           [validation_graph_dequantized_input_index],
624                       buffers.size());
625       validation_graph_dequantized_input_index++;
626       auto t = validation_model_subgraph->tensors()->Get(*i);
627       auto status = helper_.CopyTableToVector(
628           "tflite.Buffer", validation_model_->buffers()->Get(t->buffer()),
629           &buffers);
630       if (!status.ok()) {
631         return status;
632       }
633     }
634 
635     // Validation outputs.
636     RET_CHECK_INDEX(tensor_info.validation_outputs.size(),
637                     validation_model_subgraph->outputs()->size());
638     int validation_graph_output_index = 0;
639     for (auto i = validation_model_subgraph->outputs()->cbegin();
640          i != validation_model_subgraph->outputs()->cend(); i++) {
641       RET_CHECK_INDEX(
642           tensor_info.validation_outputs[validation_graph_output_index],
643           buffers.size());
644       validation_graph_output_index++;
645       auto t = validation_model_subgraph->tensors()->Get(*i);
646       auto status = helper_.CopyTableToVector(
647           "tflite.Buffer", validation_model_->buffers()->Get(t->buffer()),
648           &buffers);
649       if (!status.ok()) {
650         return status;
651       }
652     }
653     return fbb_.CreateVector(buffers);
654 #undef RET_CHECK_INDEX
655   }
656 
657   const Model* main_model_;
658   std::vector<std::string> jpeg_data_;
659   float scale_;
660   int64_t zero_point_;
661   const Model* validation_model_;
662   const reflection::Schema* schema_;
663   fb::FlatBufferBuilder fbb_;
664   FlatbufferHelper helper_;
665   bool use_ondevice_cpu_for_golden_;
666 };
667 
// Out-of-class definitions of the static constant members. Without a
// definition they can be used as compile-time constants but cannot be passed
// by reference (e.g., used with absl::StrFormat).
const int32_t ValidationGraphBuilder::kModelVersion;
const int32_t ValidationGraphBuilder::kSkippedIndex;
const int32_t ValidationGraphBuilder::kCallOperatorCode;
const int32_t ValidationGraphBuilder::kDequantizeOperatorCode;
const int32_t ValidationGraphBuilder::kDecodeJpegOperatorCode;
const int32_t ValidationGraphBuilder::kMainSubgraphIndex;
const int32_t ValidationGraphBuilder::kValidationSubgraphIndex;
678 
ValidationGraphBuilder(const Model * main_model,std::vector<std::string> jpeg_data,float scale,int64_t zero_point,const Model * validation_model,const reflection::Schema * schema,bool use_ondevice_cpu_for_golden)679 ValidationGraphBuilder::ValidationGraphBuilder(
680     const Model* main_model, std::vector<std::string> jpeg_data, float scale,
681     int64_t zero_point, const Model* validation_model,
682     const reflection::Schema* schema, bool use_ondevice_cpu_for_golden)
683     : main_model_(main_model),
684       jpeg_data_(jpeg_data),
685       scale_(scale),
686       zero_point_(zero_point),
687       validation_model_(validation_model),
688       schema_(schema),
689       helper_(&fbb_, schema_),
690       use_ondevice_cpu_for_golden_(use_ondevice_cpu_for_golden) {}
691 
BuildIntermediateModel(fb::FlatBufferBuilder * fbb)692 absl::Status ValidationGraphBuilder::BuildIntermediateModel(
693     fb::FlatBufferBuilder* fbb) {
694   fbb_.Reset();
695   auto model = MakeModel(/* intermediate_only */ true,
696                          /* subgraph_with_golden_outputs */ nullptr);
697   if (!model.ok()) {
698     return model.status();
699   }
700   fbb_.Finish(*model, "TFL3");
701   std::vector<const Model*> models{main_model_,
702                                    fb::GetRoot<Model>(fbb_.GetBufferPointer())};
703   std::vector<std::string> subgraph_names_not_important(2);
704   return CombineModels(fbb, models, subgraph_names_not_important, schema_);
705 }
706 
BuildFinalModel(fb::FlatBufferBuilder * fbb,Subgraph * subgraph_with_golden_outputs)707 absl::Status ValidationGraphBuilder::BuildFinalModel(
708     fb::FlatBufferBuilder* fbb, Subgraph* subgraph_with_golden_outputs) {
709   fbb_.Reset();
710   auto model =
711       MakeModel(/* intermediate_only */ false, subgraph_with_golden_outputs);
712   if (!model.ok()) {
713     return model.status();
714   }
715   fbb_.Finish(*model, "TFL3");
716   std::vector<const Model*> models{main_model_,
717                                    fb::GetRoot<Model>(fbb_.GetBufferPointer()),
718                                    validation_model_};
719   std::vector<std::string> subgraph_names;
720   auto main_subgraph_name = main_model_->subgraphs()->Get(0)->name();
721   subgraph_names.push_back(main_subgraph_name ? main_subgraph_name->str() : "");
722   subgraph_names.push_back("VALIDATION:main");
723   subgraph_names.push_back("VALIDATION:metrics");
724   return CombineModels(fbb, models, subgraph_names, schema_);
725 }
726 
DescribeShape(const fb::Vector<int32_t> * shape)727 std::string DescribeShape(const fb::Vector<int32_t>* shape) {
728   std::string desc = "[";
729   for (int i = 0; i < shape->size(); i++) {
730     if (i != 0) {
731       desc += ", ";
732     }
733     desc += absl::StrFormat("%d", shape->Get(i));
734   }
735   desc += "]";
736   return desc;
737 }
738 
739 }  // namespace
740 
// Records the inputs from which a model with embedded validation is later
// built. The constructor only stores its arguments; consistency checks are
// performed by ValidateInputs().
//
// NOTE(review): the model and schema pointers are stored as given, so the
// pointed-to objects presumably must outlive this Embedder -- confirm against
// the member declarations in embedder.h.
Embedder::Embedder(const Model* main_model,
                   const std::vector<std::string>& jpeg_data, float scale,
                   int64_t zero_point, const Model* validation_model,
                   const reflection::Schema* schema,
                   bool use_ondevice_cpu_for_golden)
    : main_model_(main_model),
      jpeg_data_(jpeg_data),
      scale_(scale),
      zero_point_(zero_point),
      validation_model_(validation_model),
      schema_(schema),
      use_ondevice_cpu_for_golden_(use_ondevice_cpu_for_golden) {}
753 
ValidateInputs()754 absl::Status Embedder::ValidateInputs() {
755 #define VALIDATE(condition, ...)                                     \
756   if (!(condition)) {                                                \
757     return absl::InvalidArgumentError(absl::StrFormat(__VA_ARGS__)); \
758   }
759   VALIDATE(main_model_, "main_model may not be null");
760   VALIDATE(main_model_->subgraphs()->size(), "main model must have subgraphs");
761   const SubGraph* main_subgraph = main_model_->subgraphs()->Get(0);
762   VALIDATE(main_subgraph->inputs()->size() == 1,
763            "main subgraph must have 1 input (got %d)",
764            main_subgraph->inputs()->size());
765   const auto* shape =
766       main_subgraph->tensors()->Get(main_subgraph->inputs()->Get(0))->shape();
767   VALIDATE(shape->size() == 4,
768            "main subgraph input must have 4 dimensions (got %d)",
769            shape->size());
770   VALIDATE(shape->Get(0) == 1,
771            "main subgraph input must have batch size (got %d)", shape->Get(0));
772   VALIDATE(shape->Get(3) == 1 || shape->Get(3) == 3,
773            "main subgraph input must have 1 or 3 channels (got %d)",
774            shape->Get(3));
775 
776   VALIDATE(!jpeg_data_.empty(), "must have at least 1 jpeg input");
777   int jpeg_number = 0;
778   for (const std::string& jpeg_image_data : jpeg_data_) {
779     int width, height, components;
780     decode_jpeg_kernel::JpegHeader header{0};
781     auto status = decode_jpeg_kernel::ReadJpegHeader(
782         {jpeg_image_data.data(), static_cast<int>(jpeg_image_data.size())},
783         &header);
784     VALIDATE(status.code == kTfLiteOk,
785              "Failed to decompress jpeg data at index %d: %s", jpeg_number,
786              status.error_message.c_str());
787     width = header.width;
788     height = header.height;
789     components = header.channels;
790     VALIDATE(height == shape->Get(1) && width == shape->Get(2) &&
791                  components == shape->Get(3),
792              "Jpeg input at index %d has different size from input tensor "
793              "(jpeg h: %d, w: %d, c: %d; tensor h: %d, w: %d, c: %d)",
794              jpeg_number, height, width, components, shape->Get(1),
795              shape->Get(2), shape->Get(3));
796     jpeg_number++;
797   }
798 
799   int main_output_count = main_subgraph->outputs()->size();
800   VALIDATE(main_output_count > 0,
801            "main subgraph must have at least 1 output (got %d)",
802            main_output_count);
803   VALIDATE(validation_model_->subgraphs()->size(),
804            "validation model must have subgraphs");
805   const SubGraph* validation_subgraph = validation_model_->subgraphs()->Get(0);
806   int validation_input_count = validation_subgraph->inputs()->size();
807   VALIDATE(
808       validation_input_count == main_output_count * 2,
809       "validation subgraph input count must be 2 times main subgraph output "
810       "count (validation output count: %d, main subgraph output count: %d)",
811       validation_input_count, main_output_count);
812   for (int i = 0; i < main_output_count; i++) {
813     auto main_output_tensor =
814         main_subgraph->tensors()->Get(main_subgraph->outputs()->Get(i));
815     auto main_output_shape = DescribeShape(main_output_tensor->shape());
816     VALIDATE(main_output_shape != "[]",
817              "All main outputs must be tensors, %d is a scalar", i);
818     VALIDATE(main_output_tensor->name()->str().find(kMetricPrefix) != 0,
819              "Main output %d name %s clashes with metrics/ prefix", i,
820              main_output_tensor->name()->c_str());
821     auto validation_input_shape_1 =
822         DescribeShape(validation_subgraph->tensors()
823                           ->Get(validation_subgraph->inputs()->Get(i))
824                           ->shape());
825     auto validation_input_shape_2 = DescribeShape(
826         validation_subgraph->tensors()
827             ->Get(validation_subgraph->inputs()->Get(main_output_count + i))
828             ->shape());
829     VALIDATE(main_output_shape == validation_input_shape_1,
830              "Main output %d dimensions %s do not match validation input %d "
831              "dimensions %s",
832              i, main_output_shape, i, validation_input_shape_1);
833     VALIDATE(main_output_shape == validation_input_shape_2,
834              "Main output %d dimensions %s do not match validation input %d "
835              "dimensions %s",
836              i, main_output_shape, main_output_count + i,
837              validation_input_shape_2);
838   }
839   int validation_output_count = validation_subgraph->outputs()->size();
840   VALIDATE(validation_output_count >= 2,
841            "validation output count must be at least 2 (got "
842            "%d)",
843            validation_output_count);
844   bool seen_ok = false;
845   const std::string kOk = "ok";
846   const std::string kPrefixedOk = kMetricPrefix + kOk;
847   std::string names = "";
848   for (int i = 0; i < validation_output_count; i++) {
849     const Tensor* t = validation_subgraph->tensors()->Get(
850         validation_subgraph->outputs()->Get(i));
851     VALIDATE(t->shape()->size(),
852              "validation outputs must be tensors, %d is a scalar", i);
853     seen_ok = (seen_ok || (kOk == t->name()->str()) ||
854                (kPrefixedOk == t->name()->str()));
855     if (i != 0) {
856       names += ", ";
857     }
858     names += t->name()->str();
859   }
860   VALIDATE(seen_ok, "validation must have an output named 'ok' (saw %s)",
861            names);
862 #undef VALIDATE
863   return absl::OkStatus();
864 }
865 
CreateModelWithEmbeddedValidation(fb::FlatBufferBuilder * fbb,ops::builtin::BuiltinOpResolver * resolver)866 absl::Status Embedder::CreateModelWithEmbeddedValidation(
867     fb::FlatBufferBuilder* fbb, ops::builtin::BuiltinOpResolver* resolver) {
868   auto status = ValidateInputs();
869   if (!status.ok()) {
870     return status;
871   }
872   fb::FlatBufferBuilder intermediate_fbb;
873   ValidationGraphBuilder builder(main_model_, jpeg_data_, scale_, zero_point_,
874                                  validation_model_, schema_,
875                                  use_ondevice_cpu_for_golden_);
876   status = builder.BuildIntermediateModel(&intermediate_fbb);
877   if (!status.ok()) {
878     return status;
879   }
880   auto intermediate_model = FlatBufferModel::VerifyAndBuildFromBuffer(
881       reinterpret_cast<const char*>(intermediate_fbb.GetBufferPointer()),
882       intermediate_fbb.GetSize());
883   if (!intermediate_model) {
884     return absl::InternalError("Failed to load intermediate model");
885   }
886   std::unique_ptr<Interpreter> interpreter;
887   InterpreterBuilder(*intermediate_model, *resolver)(&interpreter);
888   if (!interpreter) {
889     return absl::InternalError(
890         "Failed to build interpreter from intermediate model");
891   }
892   Subgraph* subgraph = interpreter->subgraph(1);
893   if (subgraph->AllocateTensors() != kTfLiteOk) {
894     return absl::InternalError(
895         "Failed to AllocateTensors() on validation subgraph of intermediate "
896         "model");
897   }
898   if (subgraph->Invoke() != kTfLiteOk) {
899     return absl::InternalError(
900         "Failed to Invoke() on validation subgraph of intermediate model");
901   }
902   return builder.BuildFinalModel(fbb, subgraph);
903 }
904 
905 }  // namespace acceleration
906 }  // namespace tflite
907