• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/tools/signature/signature_def_util.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 namespace {
30 
31 using tensorflow::Status;
32 using SerializedSignatureDefMap = std::map<std::string, std::string>;
33 using SignatureDefMap = std::map<std::string, tensorflow::SignatureDef>;
34 
GetSignatureDefMetadata(const Model * model)35 const Metadata* GetSignatureDefMetadata(const Model* model) {
36   if (!model || !model->metadata()) {
37     return nullptr;
38   }
39   for (int i = 0; i < model->metadata()->size(); ++i) {
40     const Metadata* metadata = model->metadata()->Get(i);
41     if (metadata->name()->str() == kSignatureDefsMetadataName) {
42       return metadata;
43     }
44   }
45   return nullptr;
46 }
47 
ReadSignatureDefMap(const Model * model,const Metadata * metadata,SerializedSignatureDefMap * map)48 Status ReadSignatureDefMap(const Model* model, const Metadata* metadata,
49                            SerializedSignatureDefMap* map) {
50   if (!model || !metadata || !map) {
51     return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
52   }
53   const flatbuffers::Vector<uint8_t>* flatbuffer_data =
54       model->buffers()->Get(metadata->buffer())->data();
55   const auto signature_defs =
56       flexbuffers::GetRoot(flatbuffer_data->data(), flatbuffer_data->size())
57           .AsMap();
58   for (int i = 0; i < signature_defs.Keys().size(); ++i) {
59     const std::string key = signature_defs.Keys()[i].AsString().c_str();
60     (*map)[key] = signature_defs[key].AsString().c_str();
61   }
62   return tensorflow::Status::OK();
63 }
64 
65 }  // namespace
66 
SetSignatureDefMap(const Model * model,const SignatureDefMap & signature_def_map,std::string * model_data_with_signature_def)67 Status SetSignatureDefMap(const Model* model,
68                           const SignatureDefMap& signature_def_map,
69                           std::string* model_data_with_signature_def) {
70   if (!model || !model_data_with_signature_def) {
71     return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
72   }
73   if (signature_def_map.empty()) {
74     return tensorflow::errors::InvalidArgument(
75         "signature_def_map should not be empty");
76   }
77   flexbuffers::Builder fbb;
78   const size_t start_map = fbb.StartMap();
79   auto mutable_model = absl::make_unique<ModelT>();
80   model->UnPackTo(mutable_model.get(), nullptr);
81   int buffer_id = mutable_model->buffers.size();
82   const Metadata* metadata = GetSignatureDefMetadata(model);
83   if (metadata) {
84     buffer_id = metadata->buffer();
85   } else {
86     auto buffer = absl::make_unique<BufferT>();
87     mutable_model->buffers.emplace_back(std::move(buffer));
88     auto sigdef_metadata = absl::make_unique<MetadataT>();
89     sigdef_metadata->buffer = buffer_id;
90     sigdef_metadata->name = kSignatureDefsMetadataName;
91     mutable_model->metadata.emplace_back(std::move(sigdef_metadata));
92   }
93   for (const auto& entry : signature_def_map) {
94     fbb.String(entry.first.c_str(), entry.second.SerializeAsString());
95   }
96   fbb.EndMap(start_map);
97   fbb.Finish();
98   mutable_model->buffers[buffer_id]->data = fbb.GetBuffer();
99   flatbuffers::FlatBufferBuilder builder;
100   auto packed_model = Model::Pack(builder, mutable_model.get());
101   FinishModelBuffer(builder, packed_model);
102   *model_data_with_signature_def =
103       std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
104                   builder.GetSize());
105   return Status::OK();
106 }
107 
HasSignatureDef(const Model * model,const std::string & signature_key)108 bool HasSignatureDef(const Model* model, const std::string& signature_key) {
109   if (!model) {
110     return false;
111   }
112   const Metadata* metadata = GetSignatureDefMetadata(model);
113   if (!metadata) {
114     return false;
115   }
116   SerializedSignatureDefMap signature_defs;
117   if (ReadSignatureDefMap(model, metadata, &signature_defs) !=
118       tensorflow::Status::OK()) {
119     return false;
120   }
121   return (signature_defs.find(signature_key) != signature_defs.end());
122 }
123 
GetSignatureDefMap(const Model * model,SignatureDefMap * signature_def_map)124 Status GetSignatureDefMap(const Model* model,
125                           SignatureDefMap* signature_def_map) {
126   if (!model || !signature_def_map) {
127     return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
128   }
129   SignatureDefMap retrieved_signature_def_map;
130   const Metadata* metadata = GetSignatureDefMetadata(model);
131   if (metadata) {
132     SerializedSignatureDefMap signature_defs;
133     auto status = ReadSignatureDefMap(model, metadata, &signature_defs);
134     if (status != tensorflow::Status::OK()) {
135       return tensorflow::errors::Internal("Error reading signature def map: ",
136                                           status.error_message());
137     }
138     for (const auto& entry : signature_defs) {
139       tensorflow::SignatureDef signature_def;
140       if (!signature_def.ParseFromString(entry.second)) {
141         return tensorflow::errors::Internal(
142             "Cannot parse signature def found in flatbuffer.");
143       }
144       retrieved_signature_def_map[entry.first] = signature_def;
145     }
146     *signature_def_map = retrieved_signature_def_map;
147   }
148   return Status::OK();
149 }
150 
ClearSignatureDefMap(const Model * model,std::string * model_data)151 Status ClearSignatureDefMap(const Model* model, std::string* model_data) {
152   if (!model || !model_data) {
153     return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
154   }
155   auto mutable_model = absl::make_unique<ModelT>();
156   model->UnPackTo(mutable_model.get(), nullptr);
157   for (int id = 0; id < model->metadata()->size(); ++id) {
158     const Metadata* metadata = model->metadata()->Get(id);
159     if (metadata->name()->str() == kSignatureDefsMetadataName) {
160       auto* buffers = &(mutable_model->buffers);
161       buffers->erase(buffers->begin() + metadata->buffer());
162       mutable_model->metadata.erase(mutable_model->metadata.begin() + id);
163       break;
164     }
165   }
166   flatbuffers::FlatBufferBuilder builder;
167   auto packed_model = Model::Pack(builder, mutable_model.get());
168   FinishModelBuffer(builder, packed_model);
169   *model_data =
170       std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
171                   builder.GetSize());
172   return Status::OK();
173 }
174 
175 }  // namespace tflite
176