• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_EMBEDDER_H_
16 #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_EMBEDDER_H_
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/status/status.h"
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "flatbuffers/reflection_generated.h"  // from @flatbuffers
24 #include "tensorflow/lite/kernels/register.h"
25 #include "tensorflow/lite/model.h"
26 #include "tensorflow/lite/schema/reflection/schema_generated.h"
27 namespace tflite {
28 namespace acceleration {
29 // Class to embed a mini-benchmark into a tflite file.
30 //
31 // The inputs are:
32 // - 'main_model': the actual inference graph (e.g., mobilenet classifier)
33 // - 'jpeg_data': jpeg images used as test data.
34 // - 'validation_model': a graph that takes as input two sets of values (the
35 // known-good main model output and the to-be-tested main model output) and
36 // produces 2 or more outputs where one must be called 'ok' (whether the
37 // results are good enough) and rest are metrics that were used to determine
38 // 'ok' and can be used for debugging/telemetry.
39 // (Known good outputs are produced inside this class, i.e. running TFLite CPU
40 // on the build host).
41 //
42 // The output is:
43 // - A new benchmark model which has 3 subgraphs. The 'main_model' subgraph, a
44 // new 'validate' subgraph that invokes the other two subgraphs when required,
45 // and the 'validation_model' subgraph.
46 // - The model output is the output of 'validation_model' + output of
47 // 'main_model'
48 // - This model has additional buffers that store the 'jpeg_data' and the actual
49 // outputs.
50 // - The 'main_model' subgraph is fed the 'jpeg_data' and produces an output
51 // which is used by the 'validation_model' with the known-good outputs to
52 // evaluate the model.
53 // - This entire process is handled end-to-end by the 'validate' subgraph using
54 // two custom ops: 'validate/call' (implemented in :call in this directory) and
55 // 'validate/decode_jpeg' (being implemented).
56 //
57 // Constraints on inputs:
58 // - 'main_model' must have a single input of dimensions
59 //   [1, height, width, 1 or 3]
60 // - the images encoded in 'jpeg_data' must have same height, width and channels
61 //   as 'main_model' input
62 // - the 'validation_model' must have inputs equal to 'main_model' outputs
63 //   duplicated (e.g, if 'main_model' has outputs with dimensions
64 //   [1, 10] and [1, 20]; the 'validation_model' must have inputs with
65 //   dimensions [1, 10], [1, 20], [1, 10], [1, 20]).
66 // - the 'validation_model' must have 2 or more outputs, and one of them must be
67 //   called 'ok'.
68 // - all inputs and outputs must be tensors (not scalars).
69 //
70 // TODO(b/172541832):
71 // - Mark the validation graph so that it's not delegated in the inference case.
72 // - Allow known-good outputs to be given rather than always being calculated
73 // inside this class.
74 class Embedder {
75  public:
76   // Construct Embedder with inputs. The Model* inputs are owned by the caller
77   // and must outlive the Embedder. The `schema` must contain the tflite
78   // flatbuffer schema. If the model is quantized, scale and zero_point are
79   // ignored.
80   Embedder(const Model* main_model, const std::vector<std::string>& jpeg_data,
81            float scale, int64_t zero_point, const Model* validation_model,
82            const reflection::Schema* schema,
83            bool use_ondevice_cpu_for_golden = false);
84   // Construct the output model. Calls Finish() on 'fbb'.
85   // The 'resolver' must have the call and decode_jpeg ops from this directory
86   // registered as 'validation/call' and 'validation/decode_jpeg'.
87   absl::Status CreateModelWithEmbeddedValidation(
88       flatbuffers::FlatBufferBuilder* fbb,
89       ::tflite::ops::builtin::BuiltinOpResolver* resolver);
90   // Check that the inputs fulfill the constraints. Called automatically as part
91   // of CreateModelWithEmbeddedValidation.
92   absl::Status ValidateInputs();
93 
94  private:
95   const Model* main_model_;
96   std::vector<std::string> jpeg_data_;
97   float scale_;
98   int64_t zero_point_;
99   const Model* validation_model_;
100   const reflection::Schema* schema_;
101   bool use_ondevice_cpu_for_golden_;
102 };
103 
104 }  // namespace acceleration
105 }  // namespace tflite
106 
107 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_EMBEDDER_H_
108