1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_EMBEDDER_H_ 16 #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_EMBEDDER_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "absl/status/status.h" 22 #include "flatbuffers/flatbuffers.h" // from @flatbuffers 23 #include "flatbuffers/reflection_generated.h" // from @flatbuffers 24 #include "tensorflow/lite/kernels/register.h" 25 #include "tensorflow/lite/model.h" 26 #include "tensorflow/lite/schema/reflection/schema_generated.h" 27 namespace tflite { 28 namespace acceleration { 29 // Class to embed a mini-benchmark into a tflite file. 30 // 31 // The inputs are: 32 // - 'main_model': the actual inference graph (e.g., mobilenet classifier) 33 // - 'jpeg_data': jpeg images used as test data. 34 // - 'validation_model': a graph that takes as input two sets of values (the 35 // known-good main model output and the to-be-tested main model output) and 36 // produces 2 or more outputs where one must be called 'ok' (whether the 37 // results are good enough) and rest are metrics that were used to determine 38 // 'ok' and can be used for debugging/telemetry. 39 // (Known good outputs are produced inside this class, i.e. running TFLite CPU 40 // on the build host). 41 // 42 // The output is: 43 // - A new benchmark model which has 3 subgraphs. The 'main_model' subgraph, a 44 // new 'validate' subgraph that invokes the other two subgraphs when required, 45 // and the 'validation_model' subgraph. 46 // - The model output is the output of 'validation_model' + output of 47 // 'main_model' 48 // - This model has additional buffers that store the 'jpeg_data' and the actual 49 // outputs. 50 // - The 'main_model' subgraph is fed the 'jpeg_data' and produces an output 51 // which is used by the 'validation_model' with the known-good outputs to 52 // evaluate the model. 53 // - This entire process is handled end-to-end by the 'validate' subgraph using 54 // two custom ops: 'validate/call' (implemented in :call in this directory) and 55 // 'validate/decode_jpeg' (being implemented). 56 // 57 // Constraints on inputs: 58 // - 'main_model' must have a single input of dimensions 59 // [1, height, width, 1 or 3] 60 // - the images encoded in 'jpeg_data' must have same height, width and channels 61 // as 'main_model' input 62 // - the 'validation_model' must have inputs equal to 'main_model' outputs 63 // duplicated (e.g, if 'main_model' has outputs with dimensions 64 // [1, 10] and [1, 20]; the 'validation_model' must have inputs with 65 // dimensions [1, 10], [1, 20], [1, 10], [1, 20]). 66 // - the 'validation_model' must have 2 or more outputs, and one of them must be 67 // called 'ok'. 68 // - all inputs and outputs must be tensors (not scalars). 69 // 70 // TODO(b/172541832): 71 // - Mark the validation graph so that it's not delegated in the inference case. 72 // - Allow known-good outputs to be given rather than always being calculated 73 // inside this class. 74 class Embedder { 75 public: 76 // Construct Embedder with inputs. The Model* inputs are owned by the caller 77 // and must outlive the Embedder. The `schema` must contain the tflite 78 // flatbuffer schema. If the model is quantized, scale and zero_point are 79 // ignored. 80 Embedder(const Model* main_model, const std::vector<std::string>& jpeg_data, 81 float scale, int64_t zero_point, const Model* validation_model, 82 const reflection::Schema* schema, 83 bool use_ondevice_cpu_for_golden = false); 84 // Construct the output model. Calls Finish() on 'fbb'. 85 // The 'resolver' must have the call and decode_jpeg ops from this directory 86 // registered as 'validation/call' and 'validation/decode_jpeg'. 87 absl::Status CreateModelWithEmbeddedValidation( 88 flatbuffers::FlatBufferBuilder* fbb, 89 ::tflite::ops::builtin::BuiltinOpResolver* resolver); 90 // Check that the inputs fulfill the constraints. Called automatically as part 91 // of CreateModelWithEmbeddedValidation. 92 absl::Status ValidateInputs(); 93 94 private: 95 const Model* main_model_; 96 std::vector<std::string> jpeg_data_; 97 float scale_; 98 int64_t zero_point_; 99 const Model* validation_model_; 100 const reflection::Schema* schema_; 101 bool use_ondevice_cpu_for_golden_; 102 }; 103 104 } // namespace acceleration 105 } // namespace tflite 106 107 #endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_EMBEDDER_H_ 108