1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/grafter.h"
16
17 #include <stdint.h>
18
19 #include <string>
20 #include <vector>
21
22 #include "absl/status/status.h"
23 #include "absl/status/statusor.h"
24 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
25 #include "flatbuffers/idl.h" // from @flatbuffers
26 #include "flatbuffers/reflection.h" // from @flatbuffers
27 #include "flatbuffers/reflection_generated.h" // from @flatbuffers
28 #include "tensorflow/lite/model.h"
29 #include "tensorflow/lite/schema/reflection/schema_generated.h"
30
31 namespace fb = flatbuffers;
32
33 namespace tflite {
34 namespace acceleration {
35
36 namespace {
37
38 class Combiner : FlatbufferHelper {
39 public:
Combiner(flatbuffers::FlatBufferBuilder * fbb,std::vector<const Model * > models,std::vector<std::string> subgraph_names,const reflection::Schema * schema)40 Combiner(flatbuffers::FlatBufferBuilder* fbb,
41 std::vector<const Model*> models,
42 std::vector<std::string> subgraph_names,
43 const reflection::Schema* schema)
44 : FlatbufferHelper(fbb, schema),
45 fbb_(fbb),
46 models_(models),
47 subgraph_names_(subgraph_names),
48 schema_(schema) {}
Combine()49 absl::Status Combine() {
50 auto operator_codes = OperatorCodes();
51 if (!operator_codes.ok()) {
52 return operator_codes.status();
53 }
54 auto subgraphs = SubGraphs();
55 if (!subgraphs.ok()) {
56 return subgraphs.status();
57 }
58 auto buffers = Buffers();
59 if (!buffers.ok()) {
60 return buffers.status();
61 }
62 auto metadata = Metadatas();
63 if (!metadata.ok()) {
64 return metadata.status();
65 }
66 auto signature_defs = SignatureDefs();
67 if (!signature_defs.ok()) {
68 return signature_defs.status();
69 }
70 fb::Offset<Model> model = CreateModel(
71 *fbb_, 3, *operator_codes, *subgraphs,
72 fbb_->CreateString(models_[0]->description()->str()), *buffers,
73 /* metadata_buffer */ 0, *metadata, *signature_defs);
74 fbb_->Finish(model, "TFL3");
75 return absl::OkStatus();
76 }
77
78 private:
79 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<OperatorCode>>>>
OperatorCodes()80 OperatorCodes() {
81 std::vector<fb::Offset<OperatorCode>> codes;
82 for (const Model* model : models_) {
83 for (int i = 0; i < model->operator_codes()->size(); i++) {
84 auto status = CopyTableToVector(
85 "tflite.OperatorCode", model->operator_codes()->Get(i), &codes);
86 if (!status.ok()) {
87 return status;
88 }
89 }
90 }
91 return fbb_->CreateVector(codes);
92 }
SubGraphs()93 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<SubGraph>>>> SubGraphs() {
94 std::vector<fb::Offset<SubGraph>> graphs;
95 int buffer_offset = 0;
96 int operator_code_offset = 0;
97 int subgraph_index = 0;
98 for (const Model* model : models_) {
99 if (model->subgraphs()->size() != 1) {
100 return absl::InvalidArgumentError(
101 "Every model to be combined must have a single subgraph.");
102 }
103 auto graph =
104 AdjustSubGraph(model->subgraphs()->Get(0), buffer_offset,
105 operator_code_offset, subgraph_names_[subgraph_index]);
106 if (!graph.ok()) {
107 return graph.status();
108 }
109 graphs.push_back(*graph);
110 buffer_offset += model->buffers()->size();
111 operator_code_offset += model->operator_codes()->size();
112 ++subgraph_index;
113 }
114 return fbb_->CreateVector(graphs);
115 }
Buffers()116 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Buffer>>>> Buffers() {
117 std::vector<fb::Offset<Buffer>> buffers;
118 for (const Model* model : models_) {
119 for (int i = 0; i < model->buffers()->size(); i++) {
120 auto status = CopyTableToVector("tflite.Buffer",
121 model->buffers()->Get(i), &buffers);
122 if (!status.ok()) {
123 return status;
124 }
125 }
126 }
127 return fbb_->CreateVector(buffers);
128 }
Metadatas()129 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Metadata>>>> Metadatas() {
130 std::vector<fb::Offset<Metadata>> metadatas;
131 int buffer_offset = 0;
132 for (const Model* model : models_) {
133 for (int i = 0; model->metadata() && i < model->metadata()->size(); i++) {
134 auto metadata =
135 AdjustMetadata(model->metadata()->Get(i), buffer_offset);
136 if (!metadata.ok()) {
137 return metadata.status();
138 }
139 metadatas.push_back(*metadata);
140 buffer_offset += model->buffers()->size();
141 }
142 }
143 return fbb_->CreateVector(metadatas);
144 }
145 absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<SignatureDef>>>>
SignatureDefs()146 SignatureDefs() {
147 std::vector<fb::Offset<SignatureDef>> signature_defs;
148 const Model* model = models_[0];
149 for (int i = 0;
150 model->signature_defs() && i < model->signature_defs()->size(); i++) {
151 auto status =
152 CopyTableToVector("tflite.SignatureDef",
153 model->signature_defs()->Get(i), &signature_defs);
154 if (!status.ok()) {
155 return status;
156 }
157 }
158 return fbb_->CreateVector(signature_defs);
159 }
160
AdjustSubGraph(const SubGraph * graph,int buffer_offset,int operator_code_offset,const std::string & name)161 absl::StatusOr<fb::Offset<SubGraph>> AdjustSubGraph(const SubGraph* graph,
162 int buffer_offset,
163 int operator_code_offset,
164 const std::string& name) {
165 auto tensors = AdjustTensors(graph, buffer_offset);
166 if (!tensors.ok()) {
167 return tensors.status();
168 }
169 auto ops = AdjustOps(graph, operator_code_offset);
170 if (!ops.ok()) {
171 return ops.status();
172 }
173 return CreateSubGraph(*fbb_, fbb_->CreateVector(*tensors),
174 CopyIntVector(graph->inputs()),
175 CopyIntVector(graph->outputs()),
176 fbb_->CreateVector(*ops), fbb_->CreateString(name));
177 }
178
AdjustOps(const SubGraph * graph,int operator_code_offset)179 absl::StatusOr<std::vector<fb::Offset<Operator>>> AdjustOps(
180 const SubGraph* graph, int operator_code_offset) {
181 std::vector<fb::Offset<Operator>> ops;
182 auto op_object = FindObject("tflite.Operator");
183 const reflection::Field* builtin_options_field = nullptr;
184 for (auto it = op_object->fields()->cbegin();
185 it != op_object->fields()->cend(); it++) {
186 auto candidate = *it;
187 if (candidate->name()->str() == "builtin_options") {
188 builtin_options_field = candidate;
189 break;
190 }
191 }
192 if (!builtin_options_field) {
193 return absl::UnknownError(
194 "Wasn't able to find the builtin_options field on tflite.Operator");
195 }
196 for (int i = 0; i < graph->operators()->size(); i++) {
197 const Operator* op = graph->operators()->Get(i);
198 fb::Offset<void> copied_builtin_options = 0;
199 if (op->builtin_options() != nullptr) {
200 const fb::Table* opt = (const fb::Table*)op; // NOLINT
201 auto& builtin_options_object = fb::GetUnionType(
202 *schema_, *op_object, *builtin_options_field, *opt);
203 copied_builtin_options =
204 fb::CopyTable(*fbb_, *schema_, builtin_options_object,
205 *fb::GetFieldT(*opt, *builtin_options_field))
206 .o;
207 }
208 ops.push_back(CreateOperator(
209 *fbb_, op->opcode_index() + operator_code_offset,
210 CopyIntVector(op->inputs()), CopyIntVector(op->outputs()),
211 op->builtin_options_type(), copied_builtin_options,
212 CopyIntVector(op->custom_options()), op->custom_options_format(),
213 CopyIntVector(op->mutating_variable_inputs()),
214 CopyIntVector(op->intermediates())));
215 }
216 return ops;
217 }
218
AdjustTensors(const SubGraph * graph,int buffer_offset)219 absl::StatusOr<std::vector<fb::Offset<Tensor>>> AdjustTensors(
220 const SubGraph* graph, int buffer_offset) {
221 std::vector<fb::Offset<Tensor>> tensors;
222 auto orig_tensors = graph->tensors();
223 for (auto iter = orig_tensors->cbegin(); iter != orig_tensors->cend();
224 iter++) {
225 auto i = *iter;
226 std::vector<int32_t> shape{i->shape()->cbegin(), i->shape()->cend()};
227 std::vector<int32_t> shape_signature;
228 if (i->shape_signature()) {
229 shape_signature.assign(i->shape_signature()->cbegin(),
230 i->shape_signature()->cend());
231 }
232 auto quantization =
233 CopyTable("tflite.QuantizationParameters", i->quantization());
234 if (!quantization.ok()) {
235 return quantization.status();
236 }
237 auto sparsity = CopyTable("tflite.SparsityParameters", i->sparsity());
238 if (!sparsity.ok()) {
239 return sparsity.status();
240 }
241 tensors.push_back(CreateTensor(
242 *fbb_, fbb_->CreateVector(shape), i->type(),
243 i->buffer() + buffer_offset, fbb_->CreateString(i->name()->str()),
244 *quantization, i->is_variable(), *sparsity,
245 shape_signature.empty() ? 0 : fbb_->CreateVector(shape_signature)));
246 }
247 return tensors;
248 }
249
AdjustMetadata(const Metadata * metadata,int buffer_offset)250 absl::StatusOr<fb::Offset<Metadata>> AdjustMetadata(const Metadata* metadata,
251 int buffer_offset) {
252 return CreateMetadata(*fbb_,
253 metadata->name()
254 ? fbb_->CreateString(metadata->name()->str())
255 : 0,
256 metadata->buffer())
257 .o;
258 }
259
260 flatbuffers::FlatBufferBuilder* fbb_;
261 std::vector<const Model*> models_;
262 std::vector<std::string> subgraph_names_;
263 const reflection::Schema* schema_;
264 };
265
266 } // namespace
267
CombineModels(flatbuffers::FlatBufferBuilder * fbb,std::vector<const Model * > models,std::vector<std::string> subgraph_names,const reflection::Schema * schema)268 absl::Status CombineModels(flatbuffers::FlatBufferBuilder* fbb,
269 std::vector<const Model*> models,
270 std::vector<std::string> subgraph_names,
271 const reflection::Schema* schema) {
272 if (!fbb || !schema) {
273 return absl::InvalidArgumentError(
274 "Must provide FlatBufferBuilder and Schema");
275 }
276 if (models.size() < 2) {
277 return absl::InvalidArgumentError("Must have 2+ models to combine");
278 }
279 Combiner combiner(fbb, models, subgraph_names, schema);
280 return combiner.Combine();
281 }
282
FlatbufferHelper(flatbuffers::FlatBufferBuilder * fbb,const reflection::Schema * schema)283 FlatbufferHelper::FlatbufferHelper(flatbuffers::FlatBufferBuilder* fbb,
284 const reflection::Schema* schema)
285 : fbb_(fbb), schema_(schema) {}
286
FindObject(const std::string & name)287 const reflection::Object* FlatbufferHelper::FindObject(
288 const std::string& name) {
289 for (auto candidate = schema_->objects()->cbegin();
290 candidate != schema_->objects()->cend(); candidate++) {
291 if (candidate->name()->str() == name) {
292 return *candidate;
293 }
294 }
295 return nullptr;
296 }
297
298 } // namespace acceleration
299 } // namespace tflite
300