• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINIBENCHMARK_GRAFTER_H_
16 #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINIBENCHMARK_GRAFTER_H_
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/status/status.h"
22 #include "absl/status/statusor.h"
23 #include "absl/strings/str_format.h"
24 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
25 #include "flatbuffers/idl.h"  // from @flatbuffers
26 #include "flatbuffers/reflection_generated.h"  // from @flatbuffers
27 #include "tensorflow/lite/model.h"
28 #include "tensorflow/lite/schema/reflection/schema_generated.h"
29 
30 namespace tflite {
31 namespace acceleration {
32 
33 // Combines the given models into one, using the FlatBufferBuilder.
34 //
35 // This is useful for constructing models that contain validation data and
36 // metrics.
37 //
38 // The model fields are handled as follows:
39 // - version is set to 3
40 // - operator codes are concatenated (no deduplication)
41 // - subgraphs are concatenated in order, rewriting operator and buffer indices
42 // to match the combined model. Subgraph names are set from 'subgraph_names'
43 // - description is taken from first model
44 // - buffers are concatenated
45 // - metadata buffer is left unset
46 // - metadata are concatenated
47 // - signature_defs are taken from the first model (as they refer to the main
48 // subgraph).
49 absl::Status CombineModels(flatbuffers::FlatBufferBuilder* fbb,
50                            std::vector<const Model*> models,
51                            std::vector<std::string> subgraph_names,
52                            const reflection::Schema* schema);
53 
54 // Convenience methods for copying flatbuffer Tables and Vectors.
55 //
56 // These are used by CombineModels above, but also needed for constructing
57 // validation subgraphs to be combined with models.
58 class FlatbufferHelper {
59  public:
60   FlatbufferHelper(flatbuffers::FlatBufferBuilder* fbb,
61                    const reflection::Schema* schema);
62   template <typename T>
CopyTableToVector(const std::string & name,const T * o,std::vector<flatbuffers::Offset<T>> * v)63   absl::Status CopyTableToVector(const std::string& name, const T* o,
64                                  std::vector<flatbuffers::Offset<T>>* v) {
65     auto copied = CopyTable(name, o);
66     if (!copied.ok()) {
67       return copied.status();
68     }
69     v->push_back(*copied);
70     return absl::OkStatus();
71   }
72   template <typename T>
CopyTable(const std::string & name,const T * o)73   absl::StatusOr<flatbuffers::Offset<T>> CopyTable(const std::string& name,
74                                                    const T* o) {
75     if (o == nullptr) return 0;
76     const reflection::Object* def = FindObject(name);
77     if (!def) {
78       return absl::NotFoundError(
79           absl::StrFormat("Type %s not found in schema", name));
80     }
81     // We want to use the general copying mechanisms that operate on
82     // flatbuffers::Table pointers. Flatbuffer types are not directly
83     // convertible to Table, as they inherit privately from table.
84     // For type* -> Table*, use reinterpret cast.
85     const flatbuffers::Table* ot =
86         reinterpret_cast<const flatbuffers::Table*>(o);
87     // For Offset<Table *> -> Offset<type>, rely on uoffset_t conversion to
88     // any flatbuffers::Offset<T>.
89     return flatbuffers::CopyTable(*fbb_, *schema_, *def, *ot).o;
90   }
91   template <typename int_type>
CopyIntVector(const flatbuffers::Vector<int_type> * from)92   flatbuffers::Offset<flatbuffers::Vector<int_type>> CopyIntVector(
93       const flatbuffers::Vector<int_type>* from) {
94     if (from == nullptr) {
95       return 0;
96     }
97     std::vector<int_type> v{from->cbegin(), from->cend()};
98     return fbb_->CreateVector(v);
99   }
100   const reflection::Object* FindObject(const std::string& name);
101 
102  private:
103   flatbuffers::FlatBufferBuilder* fbb_;
104   const reflection::Schema* schema_;
105 };
106 
107 }  // namespace acceleration
108 }  // namespace tflite
109 
#endif  // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINIBENCHMARK_GRAFTER_H_
111