• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/serialization.h"
16 
17 #if defined(_WIN32)
18 #include <fstream>
19 #include <iostream>
20 #else
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <sys/file.h>
24 #include <unistd.h>
25 
26 #include <cstring>
27 #endif  // defined(_WIN32)
28 
29 #include <time.h>
30 
31 #include <algorithm>
32 #include <cstdint>
33 #include <memory>
34 #include <string>
35 #include <vector>
36 
37 #include "tensorflow/lite/c/common.h"
38 #include "tensorflow/lite/minimal_logging.h"
39 #include "utils/hash/farmhash.h"
40 
41 namespace tflite {
42 namespace delegates {
43 namespace {
44 
// Suffix appended to a delegate id to form the cache key under which the list
// of delegated node ids is stored.
constexpr char kDelegatedNodesSuffix[] = "_dnodes";
46 
// Mixes two 64-bit fingerprints `l` and `h` into a single 64-bit value using
// a Murmur-inspired multiply/xor-shift sequence. The result depends on the
// order of the arguments.
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
  constexpr uint64_t kMultiplier = 0x9ddfea08eb382d69ULL;

  // First round: fold both inputs together.
  uint64_t first_round = (l ^ h) * kMultiplier;
  first_round ^= first_round >> 47;

  // Second round: fold `h` back in, then run two more multiply/shift steps.
  uint64_t mixed = (h ^ first_round) * kMultiplier;
  mixed ^= mixed >> 44;
  mixed *= kMultiplier;
  mixed ^= mixed >> 41;
  return mixed * kMultiplier;
}
60 
// Concatenates `path1` and `path2`, inserting a '/' separator unless `path1`
// already ends with one.
//
// An empty `path1` yields `path2` unchanged: calling back() on an empty
// std::string is undefined behavior, so that case is handled up front.
inline std::string JoinPath(const std::string& path1,
                            const std::string& path2) {
  if (path1.empty()) return path2;
  return (path1.back() == '/') ? (path1 + path2) : (path1 + "/" + path2);
}
65 
GetFilePath(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)66 inline std::string GetFilePath(const std::string& cache_dir,
67                                const std::string& model_token,
68                                const uint64_t fingerprint) {
69   auto file_name = (model_token + "_" + std::to_string(fingerprint) + ".bin");
70   return JoinPath(cache_dir, file_name);
71 }
72 
73 }  // namespace
74 
StrFingerprint(const void * data,const size_t num_bytes)75 std::string StrFingerprint(const void* data, const size_t num_bytes) {
76   return std::to_string(
77       ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
78           reinterpret_cast<const char*>(data), num_bytes));
79 }
80 
SerializationEntry(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)81 SerializationEntry::SerializationEntry(const std::string& cache_dir,
82                                        const std::string& model_token,
83                                        const uint64_t fingerprint)
84     : cache_dir_(cache_dir),
85       model_token_(model_token),
86       fingerprint_(fingerprint) {}
87 
SetData(TfLiteContext * context,const char * data,const size_t size) const88 TfLiteStatus SerializationEntry::SetData(TfLiteContext* context,
89                                          const char* data,
90                                          const size_t size) const {
91   auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
92   // Temporary file to write data to.
93   const std::string temp_filepath =
94       JoinPath(cache_dir_, (model_token_ + std::to_string(fingerprint_) +
95                             std::to_string(time(nullptr))));
96 
97 #if defined(_WIN32)
98   std::ofstream out_file(temp_filepath.c_str());
99   if (!out_file) {
100     TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Could not create file: %s",
101                     temp_filepath.c_str());
102     return kTfLiteDelegateDataWriteError;
103   }
104   out_file.write(data, size);
105   out_file.flush();
106   out_file.close();
107   // rename is an atomic operation in most systems.
108   if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
109     TF_LITE_KERNEL_LOG(context, "Failed to rename to %s", filepath.c_str());
110     return kTfLiteDelegateDataWriteError;
111   }
112 #else   // !defined(_WIN32)
113   // This method only works on unix/POSIX systems.
114   const int fd = open(temp_filepath.c_str(),
115                       O_WRONLY | O_APPEND | O_CREAT | O_CLOEXEC, 0600);
116   if (fd < 0) {
117     TF_LITE_KERNEL_LOG(context, "Failed to open for writing: %s",
118                        temp_filepath.c_str());
119     return kTfLiteDelegateDataWriteError;
120   }
121   // Loop until all bytes written.
122   ssize_t len = 0;
123   const char* buf = data;
124   do {
125     ssize_t ret = write(fd, buf, size);
126     if (ret <= 0) {
127       close(fd);
128       TF_LITE_KERNEL_LOG(context, "Failed to write data to: %s, error: %s",
129                          temp_filepath.c_str(), std::strerror(errno));
130       return kTfLiteDelegateDataWriteError;
131     }
132 
133     len += ret;
134     buf += ret;
135   } while (len < static_cast<ssize_t>(size));
136   // Use fsync to ensure data is on disk before renaming temp file.
137   if (fsync(fd) < 0) {
138     close(fd);
139     TF_LITE_KERNEL_LOG(context, "Could not fsync: %s, error: %s",
140                        temp_filepath.c_str(), std::strerror(errno));
141     return kTfLiteDelegateDataWriteError;
142   }
143   if (close(fd) < 0) {
144     TF_LITE_KERNEL_LOG(context, "Could not close fd: %s, error: %s",
145                        temp_filepath.c_str(), std::strerror(errno));
146     return kTfLiteDelegateDataWriteError;
147   }
148   if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
149     TF_LITE_KERNEL_LOG(context, "Failed to rename to %s, error: %s",
150                        filepath.c_str(), std::strerror(errno));
151     return kTfLiteDelegateDataWriteError;
152   }
153 #endif  // defined(_WIN32)
154 
155   TFLITE_LOG(TFLITE_LOG_INFO, "Wrote serialized data for model %s (%d B) to %s",
156              model_token_.c_str(), size, filepath.c_str());
157 
158   return kTfLiteOk;
159 }
160 
GetData(TfLiteContext * context,std::string * data) const161 TfLiteStatus SerializationEntry::GetData(TfLiteContext* context,
162                                          std::string* data) const {
163   if (!data) return kTfLiteError;
164   auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
165 
166 #if defined(_WIN32)
167   std::ifstream cache_stream(filepath,
168                              std::ios_base::in | std::ios_base::binary);
169   if (cache_stream.good()) {
170     cache_stream.seekg(0, cache_stream.end);
171     int cache_size = cache_stream.tellg();
172     cache_stream.seekg(0, cache_stream.beg);
173 
174     data->resize(cache_size);
175     cache_stream.read(&(*data)[0], cache_size);
176     cache_stream.close();
177   }
178 #else   // !defined(_WIN32)
179   // This method only works on unix/POSIX systems, but is more optimized & has
180   // lower size overhead for Android binaries.
181   data->clear();
182   // O_CLOEXEC is needed for correctness, as another thread may call
183   // popen() and the callee inherit the lock if it's not O_CLOEXEC.
184   int fd = open(filepath.c_str(), O_RDONLY | O_CLOEXEC, 0600);
185   if (fd < 0) {
186     TF_LITE_KERNEL_LOG(context, "File %s couldn't be opened for reading: %s",
187                        filepath.c_str(), std::strerror(errno));
188     return kTfLiteDelegateDataNotFound;
189   }
190   int lock_status = flock(fd, LOCK_EX);
191   if (lock_status < 0) {
192     close(fd);
193     TF_LITE_KERNEL_LOG(context, "Could not flock %s: %s", filepath.c_str(),
194                        std::strerror(errno));
195     return kTfLiteDelegateDataReadError;
196   }
197   char buffer[512];
198   while (true) {
199     int bytes_read = read(fd, buffer, 512);
200     if (bytes_read == 0) {
201       // EOF
202       close(fd);
203       return kTfLiteOk;
204     } else if (bytes_read < 0) {
205       close(fd);
206       TF_LITE_KERNEL_LOG(context, "Error reading %s: %s", filepath.c_str(),
207                          std::strerror(errno));
208       return kTfLiteDelegateDataReadError;
209     } else {
210       data->append(buffer, bytes_read);
211     }
212   }
213 #endif  // defined(_WIN32)
214 
215   TFLITE_LOG_PROD(TFLITE_LOG_INFO,
216                   "Found serialized data for model %s (%d B) at %s",
217                   model_token_.c_str(), data->size(), filepath.c_str());
218 
219   if (!data->empty()) {
220     TFLITE_LOG(TFLITE_LOG_INFO, "Data found at %s: %d bytes", filepath.c_str(),
221                data->size());
222     return kTfLiteOk;
223   } else {
224     TF_LITE_KERNEL_LOG(context, "No serialized data found: %s",
225                        filepath.c_str());
226     return kTfLiteDelegateDataNotFound;
227   }
228 }
229 
GetEntryImpl(const std::string & custom_key,TfLiteContext * context,const TfLiteDelegateParams * delegate_params)230 SerializationEntry Serialization::GetEntryImpl(
231     const std::string& custom_key, TfLiteContext* context,
232     const TfLiteDelegateParams* delegate_params) {
233   // First incorporate model_token.
234   // We use Fingerprint64 instead of std::hash, since the latter isn't
235   // guaranteed to be stable across runs. See b/172237993.
236   uint64_t fingerprint =
237       ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
238           model_token_.c_str(), model_token_.size());
239 
240   // Incorporate custom_key.
241   const uint64_t custom_str_fingerprint =
242       ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
243           custom_key.c_str(), custom_key.size());
244   fingerprint = CombineFingerprints(fingerprint, custom_str_fingerprint);
245 
246   // Incorporate context details, if provided.
247   // A quick heuristic involving graph tensors to 'fingerprint' a
248   // tflite::Subgraph. We don't consider the execution plan, since it could be
249   // in flux if the delegate uses this method during
250   // ReplaceNodeSubsetsWithDelegateKernels (eg in kernel Init).
251   if (context) {
252     std::vector<int32_t> context_data;
253     // Number of tensors can be large.
254     const int tensors_to_consider = std::min<int>(context->tensors_size, 100);
255     context_data.reserve(1 + tensors_to_consider);
256     context_data.push_back(context->tensors_size);
257     for (int i = 0; i < tensors_to_consider; ++i) {
258       context_data.push_back(context->tensors[i].bytes);
259     }
260     const uint64_t context_fingerprint =
261         ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
262             reinterpret_cast<char*>(context_data.data()),
263                                 context_data.size() * sizeof(int32_t));
264     fingerprint = CombineFingerprints(fingerprint, context_fingerprint);
265   }
266 
267   // Incorporate delegated partition details, if provided.
268   // A quick heuristic that considers the nodes & I/O tensor sizes to
269   // fingerprint TfLiteDelegateParams.
270   if (delegate_params) {
271     std::vector<int32_t> partition_data;
272     auto* nodes = delegate_params->nodes_to_replace;
273     auto* input_tensors = delegate_params->input_tensors;
274     auto* output_tensors = delegate_params->output_tensors;
275     partition_data.reserve(nodes->size + input_tensors->size +
276                            output_tensors->size);
277     partition_data.insert(partition_data.end(), nodes->data,
278                           nodes->data + nodes->size);
279     for (int i = 0; i < input_tensors->size; ++i) {
280       auto& tensor = context->tensors[input_tensors->data[i]];
281       partition_data.push_back(tensor.bytes);
282     }
283     for (int i = 0; i < output_tensors->size; ++i) {
284       auto& tensor = context->tensors[output_tensors->data[i]];
285       partition_data.push_back(tensor.bytes);
286     }
287     const uint64_t partition_fingerprint =
288         ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
289             reinterpret_cast<char*>(partition_data.data()),
290                                 partition_data.size() * sizeof(int32_t));
291     fingerprint = CombineFingerprints(fingerprint, partition_fingerprint);
292   }
293 
294   // Get a fingerprint-specific lock that is passed to the SerializationKey, to
295   // ensure noone else gets access to an equivalent SerializationKey.
296   return SerializationEntry(cache_dir_, model_token_, fingerprint);
297 }
298 
SaveDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,const TfLiteIntArray * node_ids)299 TfLiteStatus SaveDelegatedNodes(TfLiteContext* context,
300                                 Serialization* serialization,
301                                 const std::string& delegate_id,
302                                 const TfLiteIntArray* node_ids) {
303   if (!node_ids) return kTfLiteError;
304   std::string cache_key = delegate_id + kDelegatedNodesSuffix;
305   auto entry = serialization->GetEntryForDelegate(cache_key, context);
306   return entry.SetData(context, reinterpret_cast<const char*>(node_ids),
307                        (1 + node_ids->size) * sizeof(int));
308 }
309 
GetDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,TfLiteIntArray ** node_ids)310 TfLiteStatus GetDelegatedNodes(TfLiteContext* context,
311                                Serialization* serialization,
312                                const std::string& delegate_id,
313                                TfLiteIntArray** node_ids) {
314   if (!node_ids) return kTfLiteError;
315   std::string cache_key = delegate_id + kDelegatedNodesSuffix;
316   auto entry = serialization->GetEntryForDelegate(cache_key, context);
317 
318   std::string read_buffer;
319   TF_LITE_ENSURE_STATUS(entry.GetData(context, &read_buffer));
320   if (read_buffer.empty()) return kTfLiteOk;
321   *node_ids = TfLiteIntArrayCopy(
322       reinterpret_cast<const TfLiteIntArray*>(read_buffer.data()));
323   return kTfLiteOk;
324 }
325 
326 }  // namespace delegates
327 }  // namespace tflite
328