• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/grafter.h"
16 
17 #include <stdint.h>
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/status/status.h"
23 #include "absl/status/statusor.h"
24 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
25 #include "flatbuffers/idl.h"  // from @flatbuffers
26 #include "flatbuffers/reflection.h"  // from @flatbuffers
27 #include "flatbuffers/reflection_generated.h"  // from @flatbuffers
28 #include "tensorflow/lite/model.h"
29 #include "tensorflow/lite/schema/reflection/schema_generated.h"
30 
31 namespace fb = flatbuffers;
32 
33 namespace tflite {
34 namespace acceleration {
35 
36 namespace {
37 
38 class Combiner : FlatbufferHelper {
39  public:
Combiner(flatbuffers::FlatBufferBuilder * fbb,std::vector<const Model * > models,std::vector<std::string> subgraph_names,const reflection::Schema * schema)40   Combiner(flatbuffers::FlatBufferBuilder* fbb,
41            std::vector<const Model*> models,
42            std::vector<std::string> subgraph_names,
43            const reflection::Schema* schema)
44       : FlatbufferHelper(fbb, schema),
45         fbb_(fbb),
46         models_(models),
47         subgraph_names_(subgraph_names),
48         schema_(schema) {}
Combine()49   absl::Status Combine() {
50     auto operator_codes = OperatorCodes();
51     if (!operator_codes.ok()) {
52       return operator_codes.status();
53     }
54     auto subgraphs = SubGraphs();
55     if (!subgraphs.ok()) {
56       return subgraphs.status();
57     }
58     auto buffers = Buffers();
59     if (!buffers.ok()) {
60       return buffers.status();
61     }
62     auto metadata = Metadatas();
63     if (!metadata.ok()) {
64       return metadata.status();
65     }
66     auto signature_defs = SignatureDefs();
67     if (!signature_defs.ok()) {
68       return signature_defs.status();
69     }
70     fb::Offset<Model> model = CreateModel(
71         *fbb_, 3, *operator_codes, *subgraphs,
72         fbb_->CreateString(models_[0]->description()->str()), *buffers,
73         /* metadata_buffer */ 0, *metadata, *signature_defs);
74     fbb_->Finish(model, "TFL3");
75     return absl::OkStatus();
76   }
77 
78  private:
79   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<OperatorCode>>>>
OperatorCodes()80   OperatorCodes() {
81     std::vector<fb::Offset<OperatorCode>> codes;
82     for (const Model* model : models_) {
83       for (int i = 0; i < model->operator_codes()->size(); i++) {
84         auto status = CopyTableToVector(
85             "tflite.OperatorCode", model->operator_codes()->Get(i), &codes);
86         if (!status.ok()) {
87           return status;
88         }
89       }
90     }
91     return fbb_->CreateVector(codes);
92   }
SubGraphs()93   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<SubGraph>>>> SubGraphs() {
94     std::vector<fb::Offset<SubGraph>> graphs;
95     int buffer_offset = 0;
96     int operator_code_offset = 0;
97     int subgraph_index = 0;
98     for (const Model* model : models_) {
99       if (model->subgraphs()->size() != 1) {
100         return absl::InvalidArgumentError(
101             "Every model to be combined must have a single subgraph.");
102       }
103       auto graph =
104           AdjustSubGraph(model->subgraphs()->Get(0), buffer_offset,
105                          operator_code_offset, subgraph_names_[subgraph_index]);
106       if (!graph.ok()) {
107         return graph.status();
108       }
109       graphs.push_back(*graph);
110       buffer_offset += model->buffers()->size();
111       operator_code_offset += model->operator_codes()->size();
112       ++subgraph_index;
113     }
114     return fbb_->CreateVector(graphs);
115   }
Buffers()116   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Buffer>>>> Buffers() {
117     std::vector<fb::Offset<Buffer>> buffers;
118     for (const Model* model : models_) {
119       for (int i = 0; i < model->buffers()->size(); i++) {
120         auto status = CopyTableToVector("tflite.Buffer",
121                                         model->buffers()->Get(i), &buffers);
122         if (!status.ok()) {
123           return status;
124         }
125       }
126     }
127     return fbb_->CreateVector(buffers);
128   }
Metadatas()129   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<Metadata>>>> Metadatas() {
130     std::vector<fb::Offset<Metadata>> metadatas;
131     int buffer_offset = 0;
132     for (const Model* model : models_) {
133       for (int i = 0; model->metadata() && i < model->metadata()->size(); i++) {
134         auto metadata =
135             AdjustMetadata(model->metadata()->Get(i), buffer_offset);
136         if (!metadata.ok()) {
137           return metadata.status();
138         }
139         metadatas.push_back(*metadata);
140         buffer_offset += model->buffers()->size();
141       }
142     }
143     return fbb_->CreateVector(metadatas);
144   }
145   absl::StatusOr<fb::Offset<fb::Vector<fb::Offset<SignatureDef>>>>
SignatureDefs()146   SignatureDefs() {
147     std::vector<fb::Offset<SignatureDef>> signature_defs;
148     const Model* model = models_[0];
149     for (int i = 0;
150          model->signature_defs() && i < model->signature_defs()->size(); i++) {
151       auto status =
152           CopyTableToVector("tflite.SignatureDef",
153                             model->signature_defs()->Get(i), &signature_defs);
154       if (!status.ok()) {
155         return status;
156       }
157     }
158     return fbb_->CreateVector(signature_defs);
159   }
160 
AdjustSubGraph(const SubGraph * graph,int buffer_offset,int operator_code_offset,const std::string & name)161   absl::StatusOr<fb::Offset<SubGraph>> AdjustSubGraph(const SubGraph* graph,
162                                                       int buffer_offset,
163                                                       int operator_code_offset,
164                                                       const std::string& name) {
165     auto tensors = AdjustTensors(graph, buffer_offset);
166     if (!tensors.ok()) {
167       return tensors.status();
168     }
169     auto ops = AdjustOps(graph, operator_code_offset);
170     if (!ops.ok()) {
171       return ops.status();
172     }
173     return CreateSubGraph(*fbb_, fbb_->CreateVector(*tensors),
174                           CopyIntVector(graph->inputs()),
175                           CopyIntVector(graph->outputs()),
176                           fbb_->CreateVector(*ops), fbb_->CreateString(name));
177   }
178 
AdjustOps(const SubGraph * graph,int operator_code_offset)179   absl::StatusOr<std::vector<fb::Offset<Operator>>> AdjustOps(
180       const SubGraph* graph, int operator_code_offset) {
181     std::vector<fb::Offset<Operator>> ops;
182     auto op_object = FindObject("tflite.Operator");
183     const reflection::Field* builtin_options_field = nullptr;
184     for (auto it = op_object->fields()->cbegin();
185          it != op_object->fields()->cend(); it++) {
186       auto candidate = *it;
187       if (candidate->name()->str() == "builtin_options") {
188         builtin_options_field = candidate;
189         break;
190       }
191     }
192     if (!builtin_options_field) {
193       return absl::UnknownError(
194           "Wasn't able to find the builtin_options field on tflite.Operator");
195     }
196     for (int i = 0; i < graph->operators()->size(); i++) {
197       const Operator* op = graph->operators()->Get(i);
198       fb::Offset<void> copied_builtin_options = 0;
199       if (op->builtin_options() != nullptr) {
200         const fb::Table* opt = (const fb::Table*)op;  // NOLINT
201         auto& builtin_options_object = fb::GetUnionType(
202             *schema_, *op_object, *builtin_options_field, *opt);
203         copied_builtin_options =
204             fb::CopyTable(*fbb_, *schema_, builtin_options_object,
205                           *fb::GetFieldT(*opt, *builtin_options_field))
206                 .o;
207       }
208       ops.push_back(CreateOperator(
209           *fbb_, op->opcode_index() + operator_code_offset,
210           CopyIntVector(op->inputs()), CopyIntVector(op->outputs()),
211           op->builtin_options_type(), copied_builtin_options,
212           CopyIntVector(op->custom_options()), op->custom_options_format(),
213           CopyIntVector(op->mutating_variable_inputs()),
214           CopyIntVector(op->intermediates())));
215     }
216     return ops;
217   }
218 
AdjustTensors(const SubGraph * graph,int buffer_offset)219   absl::StatusOr<std::vector<fb::Offset<Tensor>>> AdjustTensors(
220       const SubGraph* graph, int buffer_offset) {
221     std::vector<fb::Offset<Tensor>> tensors;
222     auto orig_tensors = graph->tensors();
223     for (auto iter = orig_tensors->cbegin(); iter != orig_tensors->cend();
224          iter++) {
225       auto i = *iter;
226       std::vector<int32_t> shape{i->shape()->cbegin(), i->shape()->cend()};
227       std::vector<int32_t> shape_signature;
228       if (i->shape_signature()) {
229         shape_signature.assign(i->shape_signature()->cbegin(),
230                                i->shape_signature()->cend());
231       }
232       auto quantization =
233           CopyTable("tflite.QuantizationParameters", i->quantization());
234       if (!quantization.ok()) {
235         return quantization.status();
236       }
237       auto sparsity = CopyTable("tflite.SparsityParameters", i->sparsity());
238       if (!sparsity.ok()) {
239         return sparsity.status();
240       }
241       tensors.push_back(CreateTensor(
242           *fbb_, fbb_->CreateVector(shape), i->type(),
243           i->buffer() + buffer_offset, fbb_->CreateString(i->name()->str()),
244           *quantization, i->is_variable(), *sparsity,
245           shape_signature.empty() ? 0 : fbb_->CreateVector(shape_signature)));
246     }
247     return tensors;
248   }
249 
AdjustMetadata(const Metadata * metadata,int buffer_offset)250   absl::StatusOr<fb::Offset<Metadata>> AdjustMetadata(const Metadata* metadata,
251                                                       int buffer_offset) {
252     return CreateMetadata(*fbb_,
253                           metadata->name()
254                               ? fbb_->CreateString(metadata->name()->str())
255                               : 0,
256                           metadata->buffer())
257         .o;
258   }
259 
260   flatbuffers::FlatBufferBuilder* fbb_;
261   std::vector<const Model*> models_;
262   std::vector<std::string> subgraph_names_;
263   const reflection::Schema* schema_;
264 };
265 
266 }  // namespace
267 
CombineModels(flatbuffers::FlatBufferBuilder * fbb,std::vector<const Model * > models,std::vector<std::string> subgraph_names,const reflection::Schema * schema)268 absl::Status CombineModels(flatbuffers::FlatBufferBuilder* fbb,
269                            std::vector<const Model*> models,
270                            std::vector<std::string> subgraph_names,
271                            const reflection::Schema* schema) {
272   if (!fbb || !schema) {
273     return absl::InvalidArgumentError(
274         "Must provide FlatBufferBuilder and Schema");
275   }
276   if (models.size() < 2) {
277     return absl::InvalidArgumentError("Must have 2+ models to combine");
278   }
279   Combiner combiner(fbb, models, subgraph_names, schema);
280   return combiner.Combine();
281 }
282 
FlatbufferHelper(flatbuffers::FlatBufferBuilder * fbb,const reflection::Schema * schema)283 FlatbufferHelper::FlatbufferHelper(flatbuffers::FlatBufferBuilder* fbb,
284                                    const reflection::Schema* schema)
285     : fbb_(fbb), schema_(schema) {}
286 
FindObject(const std::string & name)287 const reflection::Object* FlatbufferHelper::FindObject(
288     const std::string& name) {
289   for (auto candidate = schema_->objects()->cbegin();
290        candidate != schema_->objects()->cend(); candidate++) {
291     if (candidate->name()->str() == name) {
292       return *candidate;
293     }
294   }
295   return nullptr;
296 }
297 
298 }  // namespace acceleration
299 }  // namespace tflite
300