1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/signature/signature_def_util.h"
16
17 #include <string>
18
19 #include "absl/memory/memory.h"
20 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
21 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/protobuf/meta_graph.pb.h"
25 #include "tensorflow/lite/model_builder.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27
28 namespace tflite {
29 namespace {
30
31 using tensorflow::Status;
32 using SerializedSignatureDefMap = std::map<std::string, std::string>;
33 using SignatureDefMap = std::map<std::string, tensorflow::SignatureDef>;
34
GetSignatureDefMetadata(const Model * model)35 const Metadata* GetSignatureDefMetadata(const Model* model) {
36 if (!model || !model->metadata()) {
37 return nullptr;
38 }
39 for (int i = 0; i < model->metadata()->size(); ++i) {
40 const Metadata* metadata = model->metadata()->Get(i);
41 if (metadata->name()->str() == kSignatureDefsMetadataName) {
42 return metadata;
43 }
44 }
45 return nullptr;
46 }
47
ReadSignatureDefMap(const Model * model,const Metadata * metadata,SerializedSignatureDefMap * map)48 Status ReadSignatureDefMap(const Model* model, const Metadata* metadata,
49 SerializedSignatureDefMap* map) {
50 if (!model || !metadata || !map) {
51 return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
52 }
53 const flatbuffers::Vector<uint8_t>* flatbuffer_data =
54 model->buffers()->Get(metadata->buffer())->data();
55 const auto signature_defs =
56 flexbuffers::GetRoot(flatbuffer_data->data(), flatbuffer_data->size())
57 .AsMap();
58 for (int i = 0; i < signature_defs.Keys().size(); ++i) {
59 const std::string key = signature_defs.Keys()[i].AsString().c_str();
60 (*map)[key] = signature_defs[key].AsString().c_str();
61 }
62 return tensorflow::Status::OK();
63 }
64
65 } // namespace
66
SetSignatureDefMap(const Model * model,const SignatureDefMap & signature_def_map,std::string * model_data_with_signature_def)67 Status SetSignatureDefMap(const Model* model,
68 const SignatureDefMap& signature_def_map,
69 std::string* model_data_with_signature_def) {
70 if (!model || !model_data_with_signature_def) {
71 return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
72 }
73 if (signature_def_map.empty()) {
74 return tensorflow::errors::InvalidArgument(
75 "signature_def_map should not be empty");
76 }
77 flexbuffers::Builder fbb;
78 const size_t start_map = fbb.StartMap();
79 auto mutable_model = absl::make_unique<ModelT>();
80 model->UnPackTo(mutable_model.get(), nullptr);
81 int buffer_id = mutable_model->buffers.size();
82 const Metadata* metadata = GetSignatureDefMetadata(model);
83 if (metadata) {
84 buffer_id = metadata->buffer();
85 } else {
86 auto buffer = absl::make_unique<BufferT>();
87 mutable_model->buffers.emplace_back(std::move(buffer));
88 auto sigdef_metadata = absl::make_unique<MetadataT>();
89 sigdef_metadata->buffer = buffer_id;
90 sigdef_metadata->name = kSignatureDefsMetadataName;
91 mutable_model->metadata.emplace_back(std::move(sigdef_metadata));
92 }
93 for (const auto& entry : signature_def_map) {
94 fbb.String(entry.first.c_str(), entry.second.SerializeAsString());
95 }
96 fbb.EndMap(start_map);
97 fbb.Finish();
98 mutable_model->buffers[buffer_id]->data = fbb.GetBuffer();
99 flatbuffers::FlatBufferBuilder builder;
100 auto packed_model = Model::Pack(builder, mutable_model.get());
101 FinishModelBuffer(builder, packed_model);
102 *model_data_with_signature_def =
103 std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
104 builder.GetSize());
105 return Status::OK();
106 }
107
HasSignatureDef(const Model * model,const std::string & signature_key)108 bool HasSignatureDef(const Model* model, const std::string& signature_key) {
109 if (!model) {
110 return false;
111 }
112 const Metadata* metadata = GetSignatureDefMetadata(model);
113 if (!metadata) {
114 return false;
115 }
116 SerializedSignatureDefMap signature_defs;
117 if (ReadSignatureDefMap(model, metadata, &signature_defs) !=
118 tensorflow::Status::OK()) {
119 return false;
120 }
121 return (signature_defs.find(signature_key) != signature_defs.end());
122 }
123
GetSignatureDefMap(const Model * model,SignatureDefMap * signature_def_map)124 Status GetSignatureDefMap(const Model* model,
125 SignatureDefMap* signature_def_map) {
126 if (!model || !signature_def_map) {
127 return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
128 }
129 SignatureDefMap retrieved_signature_def_map;
130 const Metadata* metadata = GetSignatureDefMetadata(model);
131 if (metadata) {
132 SerializedSignatureDefMap signature_defs;
133 auto status = ReadSignatureDefMap(model, metadata, &signature_defs);
134 if (status != tensorflow::Status::OK()) {
135 return tensorflow::errors::Internal("Error reading signature def map: ",
136 status.error_message());
137 }
138 for (const auto& entry : signature_defs) {
139 tensorflow::SignatureDef signature_def;
140 if (!signature_def.ParseFromString(entry.second)) {
141 return tensorflow::errors::Internal(
142 "Cannot parse signature def found in flatbuffer.");
143 }
144 retrieved_signature_def_map[entry.first] = signature_def;
145 }
146 *signature_def_map = retrieved_signature_def_map;
147 }
148 return Status::OK();
149 }
150
ClearSignatureDefMap(const Model * model,std::string * model_data)151 Status ClearSignatureDefMap(const Model* model, std::string* model_data) {
152 if (!model || !model_data) {
153 return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
154 }
155 auto mutable_model = absl::make_unique<ModelT>();
156 model->UnPackTo(mutable_model.get(), nullptr);
157 for (int id = 0; id < model->metadata()->size(); ++id) {
158 const Metadata* metadata = model->metadata()->Get(id);
159 if (metadata->name()->str() == kSignatureDefsMetadataName) {
160 auto* buffers = &(mutable_model->buffers);
161 buffers->erase(buffers->begin() + metadata->buffer());
162 mutable_model->metadata.erase(mutable_model->metadata.begin() + id);
163 break;
164 }
165 }
166 flatbuffers::FlatBufferBuilder builder;
167 auto packed_model = Model::Pack(builder, mutable_model.get());
168 FinishModelBuffer(builder, packed_model);
169 *model_data =
170 std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
171 builder.GetSize());
172 return Status::OK();
173 }
174
175 } // namespace tflite
176