1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINIBENCHMARK_GRAFTER_H_ 16 #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINIBENCHMARK_GRAFTER_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "absl/status/status.h" 22 #include "absl/status/statusor.h" 23 #include "absl/strings/str_format.h" 24 #include "flatbuffers/flatbuffers.h" // from @flatbuffers 25 #include "flatbuffers/idl.h" // from @flatbuffers 26 #include "flatbuffers/reflection_generated.h" // from @flatbuffers 27 #include "tensorflow/lite/model.h" 28 #include "tensorflow/lite/schema/reflection/schema_generated.h" 29 30 namespace tflite { 31 namespace acceleration { 32 33 // Combines the given models into one, using the FlatBufferBuilder. 34 // 35 // This is useful for constructing models that contain validation data and 36 // metrics. 37 // 38 // The model fields are handled as follows: 39 // - version is set to 3 40 // - operator codes are concatenated (no deduplication) 41 // - subgraphs are concatenated in order, rewriting operator and buffer indices 42 // to match the combined model. Subgraph names are set from 'subgraph_names' 43 // - description is taken from first model 44 // - buffers are concatenated 45 // - metadata buffer is left unset 46 // - metadata are concatenated 47 // - signature_defs are taken from the first model (as they refer to the main 48 // subgraph). 49 absl::Status CombineModels(flatbuffers::FlatBufferBuilder* fbb, 50 std::vector<const Model*> models, 51 std::vector<std::string> subgraph_names, 52 const reflection::Schema* schema); 53 54 // Convenience methods for copying flatbuffer Tables and Vectors. 55 // 56 // These are used by CombineModels above, but also needed for constructing 57 // validation subgraphs to be combined with models. 58 class FlatbufferHelper { 59 public: 60 FlatbufferHelper(flatbuffers::FlatBufferBuilder* fbb, 61 const reflection::Schema* schema); 62 template <typename T> CopyTableToVector(const std::string & name,const T * o,std::vector<flatbuffers::Offset<T>> * v)63 absl::Status CopyTableToVector(const std::string& name, const T* o, 64 std::vector<flatbuffers::Offset<T>>* v) { 65 auto copied = CopyTable(name, o); 66 if (!copied.ok()) { 67 return copied.status(); 68 } 69 v->push_back(*copied); 70 return absl::OkStatus(); 71 } 72 template <typename T> CopyTable(const std::string & name,const T * o)73 absl::StatusOr<flatbuffers::Offset<T>> CopyTable(const std::string& name, 74 const T* o) { 75 if (o == nullptr) return 0; 76 const reflection::Object* def = FindObject(name); 77 if (!def) { 78 return absl::NotFoundError( 79 absl::StrFormat("Type %s not found in schema", name)); 80 } 81 // We want to use the general copying mechanisms that operate on 82 // flatbuffers::Table pointers. Flatbuffer types are not directly 83 // convertible to Table, as they inherit privately from table. 84 // For type* -> Table*, use reinterpret cast. 85 const flatbuffers::Table* ot = 86 reinterpret_cast<const flatbuffers::Table*>(o); 87 // For Offset<Table *> -> Offset<type>, rely on uoffset_t conversion to 88 // any flatbuffers::Offset<T>. 89 return flatbuffers::CopyTable(*fbb_, *schema_, *def, *ot).o; 90 } 91 template <typename int_type> CopyIntVector(const flatbuffers::Vector<int_type> * from)92 flatbuffers::Offset<flatbuffers::Vector<int_type>> CopyIntVector( 93 const flatbuffers::Vector<int_type>* from) { 94 if (from == nullptr) { 95 return 0; 96 } 97 std::vector<int_type> v{from->cbegin(), from->cend()}; 98 return fbb_->CreateVector(v); 99 } 100 const reflection::Object* FindObject(const std::string& name); 101 102 private: 103 flatbuffers::FlatBufferBuilder* fbb_; 104 const reflection::Schema* schema_; 105 }; 106 107 } // namespace acceleration 108 } // namespace tflite 109 110 #endif // THIRD_PARTY_TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINIBENCHMARK_GRAFTER_H_ 111