1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/serialization.h"
16
17 #if defined(_WIN32)
18 #include <fstream>
19 #include <iostream>
20 #else
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <sys/file.h>
24 #include <unistd.h>
25
26 #include <cstring>
27 #endif // defined(_WIN32)
28
29 #include <time.h>
30
31 #include <algorithm>
32 #include <cstdint>
33 #include <memory>
34 #include <string>
35 #include <vector>
36
37 #include "tensorflow/lite/c/common.h"
38 #include "tensorflow/lite/minimal_logging.h"
39 #include "utils/hash/farmhash.h"
40
41 namespace tflite {
42 namespace delegates {
43 namespace {
44
// Suffix appended to a delegate id to form the cache key under which the list
// of delegated node ids is stored.
constexpr char kDelegatedNodesSuffix[] = "_dnodes";
46
// Mixes two 64-bit fingerprints into a single 64-bit value using a
// Murmur-inspired multiply / xor-shift sequence. The result depends on the
// order of the arguments.
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
  constexpr uint64_t kMultiplier = 0x9ddfea08eb382d69ULL;

  // One closing round: fold the high bits down, then multiply.
  const auto fold_and_scale = [kMultiplier](uint64_t value, int shift) {
    value ^= (value >> shift);
    return value * kMultiplier;
  };

  uint64_t low_mix = (l ^ h) * kMultiplier;
  low_mix ^= (low_mix >> 47);

  uint64_t combined = (h ^ low_mix) * kMultiplier;
  combined = fold_and_scale(combined, 44);
  combined = fold_and_scale(combined, 41);
  return combined;
}
60
// Joins two path components with exactly one '/' between them when `path1`
// does not already end in one.
//
// An empty `path1` yields `path2` unchanged; calling back() on an empty
// string would be undefined behaviour.
inline std::string JoinPath(const std::string& path1,
                            const std::string& path2) {
  if (path1.empty()) return path2;
  return (path1.back() == '/') ? (path1 + path2) : (path1 + "/" + path2);
}
65
GetFilePath(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)66 inline std::string GetFilePath(const std::string& cache_dir,
67 const std::string& model_token,
68 const uint64_t fingerprint) {
69 auto file_name = (model_token + "_" + std::to_string(fingerprint) + ".bin");
70 return JoinPath(cache_dir, file_name);
71 }
72
73 } // namespace
74
StrFingerprint(const void * data,const size_t num_bytes)75 std::string StrFingerprint(const void* data, const size_t num_bytes) {
76 return std::to_string(
77 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
78 reinterpret_cast<const char*>(data), num_bytes));
79 }
80
SerializationEntry(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)81 SerializationEntry::SerializationEntry(const std::string& cache_dir,
82 const std::string& model_token,
83 const uint64_t fingerprint)
84 : cache_dir_(cache_dir),
85 model_token_(model_token),
86 fingerprint_(fingerprint) {}
87
SetData(TfLiteContext * context,const char * data,const size_t size) const88 TfLiteStatus SerializationEntry::SetData(TfLiteContext* context,
89 const char* data,
90 const size_t size) const {
91 auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
92 // Temporary file to write data to.
93 const std::string temp_filepath =
94 JoinPath(cache_dir_, (model_token_ + std::to_string(fingerprint_) +
95 std::to_string(time(nullptr))));
96
97 #if defined(_WIN32)
98 std::ofstream out_file(temp_filepath.c_str());
99 if (!out_file) {
100 TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Could not create file: %s",
101 temp_filepath.c_str());
102 return kTfLiteDelegateDataWriteError;
103 }
104 out_file.write(data, size);
105 out_file.flush();
106 out_file.close();
107 // rename is an atomic operation in most systems.
108 if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
109 TF_LITE_KERNEL_LOG(context, "Failed to rename to %s", filepath.c_str());
110 return kTfLiteDelegateDataWriteError;
111 }
112 #else // !defined(_WIN32)
113 // This method only works on unix/POSIX systems.
114 const int fd = open(temp_filepath.c_str(),
115 O_WRONLY | O_APPEND | O_CREAT | O_CLOEXEC, 0600);
116 if (fd < 0) {
117 TF_LITE_KERNEL_LOG(context, "Failed to open for writing: %s",
118 temp_filepath.c_str());
119 return kTfLiteDelegateDataWriteError;
120 }
121 // Loop until all bytes written.
122 ssize_t len = 0;
123 const char* buf = data;
124 do {
125 ssize_t ret = write(fd, buf, size);
126 if (ret <= 0) {
127 close(fd);
128 TF_LITE_KERNEL_LOG(context, "Failed to write data to: %s, error: %s",
129 temp_filepath.c_str(), std::strerror(errno));
130 return kTfLiteDelegateDataWriteError;
131 }
132
133 len += ret;
134 buf += ret;
135 } while (len < static_cast<ssize_t>(size));
136 // Use fsync to ensure data is on disk before renaming temp file.
137 if (fsync(fd) < 0) {
138 close(fd);
139 TF_LITE_KERNEL_LOG(context, "Could not fsync: %s, error: %s",
140 temp_filepath.c_str(), std::strerror(errno));
141 return kTfLiteDelegateDataWriteError;
142 }
143 if (close(fd) < 0) {
144 TF_LITE_KERNEL_LOG(context, "Could not close fd: %s, error: %s",
145 temp_filepath.c_str(), std::strerror(errno));
146 return kTfLiteDelegateDataWriteError;
147 }
148 if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
149 TF_LITE_KERNEL_LOG(context, "Failed to rename to %s, error: %s",
150 filepath.c_str(), std::strerror(errno));
151 return kTfLiteDelegateDataWriteError;
152 }
153 #endif // defined(_WIN32)
154
155 TFLITE_LOG(TFLITE_LOG_INFO, "Wrote serialized data for model %s (%d B) to %s",
156 model_token_.c_str(), size, filepath.c_str());
157
158 return kTfLiteOk;
159 }
160
GetData(TfLiteContext * context,std::string * data) const161 TfLiteStatus SerializationEntry::GetData(TfLiteContext* context,
162 std::string* data) const {
163 if (!data) return kTfLiteError;
164 auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
165
166 #if defined(_WIN32)
167 std::ifstream cache_stream(filepath,
168 std::ios_base::in | std::ios_base::binary);
169 if (cache_stream.good()) {
170 cache_stream.seekg(0, cache_stream.end);
171 int cache_size = cache_stream.tellg();
172 cache_stream.seekg(0, cache_stream.beg);
173
174 data->resize(cache_size);
175 cache_stream.read(&(*data)[0], cache_size);
176 cache_stream.close();
177 }
178 #else // !defined(_WIN32)
179 // This method only works on unix/POSIX systems, but is more optimized & has
180 // lower size overhead for Android binaries.
181 data->clear();
182 // O_CLOEXEC is needed for correctness, as another thread may call
183 // popen() and the callee inherit the lock if it's not O_CLOEXEC.
184 int fd = open(filepath.c_str(), O_RDONLY | O_CLOEXEC, 0600);
185 if (fd < 0) {
186 TF_LITE_KERNEL_LOG(context, "File %s couldn't be opened for reading: %s",
187 filepath.c_str(), std::strerror(errno));
188 return kTfLiteDelegateDataNotFound;
189 }
190 int lock_status = flock(fd, LOCK_EX);
191 if (lock_status < 0) {
192 close(fd);
193 TF_LITE_KERNEL_LOG(context, "Could not flock %s: %s", filepath.c_str(),
194 std::strerror(errno));
195 return kTfLiteDelegateDataReadError;
196 }
197 char buffer[512];
198 while (true) {
199 int bytes_read = read(fd, buffer, 512);
200 if (bytes_read == 0) {
201 // EOF
202 close(fd);
203 return kTfLiteOk;
204 } else if (bytes_read < 0) {
205 close(fd);
206 TF_LITE_KERNEL_LOG(context, "Error reading %s: %s", filepath.c_str(),
207 std::strerror(errno));
208 return kTfLiteDelegateDataReadError;
209 } else {
210 data->append(buffer, bytes_read);
211 }
212 }
213 #endif // defined(_WIN32)
214
215 TFLITE_LOG_PROD(TFLITE_LOG_INFO,
216 "Found serialized data for model %s (%d B) at %s",
217 model_token_.c_str(), data->size(), filepath.c_str());
218
219 if (!data->empty()) {
220 TFLITE_LOG(TFLITE_LOG_INFO, "Data found at %s: %d bytes", filepath.c_str(),
221 data->size());
222 return kTfLiteOk;
223 } else {
224 TF_LITE_KERNEL_LOG(context, "No serialized data found: %s",
225 filepath.c_str());
226 return kTfLiteDelegateDataNotFound;
227 }
228 }
229
GetEntryImpl(const std::string & custom_key,TfLiteContext * context,const TfLiteDelegateParams * delegate_params)230 SerializationEntry Serialization::GetEntryImpl(
231 const std::string& custom_key, TfLiteContext* context,
232 const TfLiteDelegateParams* delegate_params) {
233 // First incorporate model_token.
234 // We use Fingerprint64 instead of std::hash, since the latter isn't
235 // guaranteed to be stable across runs. See b/172237993.
236 uint64_t fingerprint =
237 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
238 model_token_.c_str(), model_token_.size());
239
240 // Incorporate custom_key.
241 const uint64_t custom_str_fingerprint =
242 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
243 custom_key.c_str(), custom_key.size());
244 fingerprint = CombineFingerprints(fingerprint, custom_str_fingerprint);
245
246 // Incorporate context details, if provided.
247 // A quick heuristic involving graph tensors to 'fingerprint' a
248 // tflite::Subgraph. We don't consider the execution plan, since it could be
249 // in flux if the delegate uses this method during
250 // ReplaceNodeSubsetsWithDelegateKernels (eg in kernel Init).
251 if (context) {
252 std::vector<int32_t> context_data;
253 // Number of tensors can be large.
254 const int tensors_to_consider = std::min<int>(context->tensors_size, 100);
255 context_data.reserve(1 + tensors_to_consider);
256 context_data.push_back(context->tensors_size);
257 for (int i = 0; i < tensors_to_consider; ++i) {
258 context_data.push_back(context->tensors[i].bytes);
259 }
260 const uint64_t context_fingerprint =
261 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
262 reinterpret_cast<char*>(context_data.data()),
263 context_data.size() * sizeof(int32_t));
264 fingerprint = CombineFingerprints(fingerprint, context_fingerprint);
265 }
266
267 // Incorporate delegated partition details, if provided.
268 // A quick heuristic that considers the nodes & I/O tensor sizes to
269 // fingerprint TfLiteDelegateParams.
270 if (delegate_params) {
271 std::vector<int32_t> partition_data;
272 auto* nodes = delegate_params->nodes_to_replace;
273 auto* input_tensors = delegate_params->input_tensors;
274 auto* output_tensors = delegate_params->output_tensors;
275 partition_data.reserve(nodes->size + input_tensors->size +
276 output_tensors->size);
277 partition_data.insert(partition_data.end(), nodes->data,
278 nodes->data + nodes->size);
279 for (int i = 0; i < input_tensors->size; ++i) {
280 auto& tensor = context->tensors[input_tensors->data[i]];
281 partition_data.push_back(tensor.bytes);
282 }
283 for (int i = 0; i < output_tensors->size; ++i) {
284 auto& tensor = context->tensors[output_tensors->data[i]];
285 partition_data.push_back(tensor.bytes);
286 }
287 const uint64_t partition_fingerprint =
288 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(
289 reinterpret_cast<char*>(partition_data.data()),
290 partition_data.size() * sizeof(int32_t));
291 fingerprint = CombineFingerprints(fingerprint, partition_fingerprint);
292 }
293
294 // Get a fingerprint-specific lock that is passed to the SerializationKey, to
295 // ensure noone else gets access to an equivalent SerializationKey.
296 return SerializationEntry(cache_dir_, model_token_, fingerprint);
297 }
298
SaveDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,const TfLiteIntArray * node_ids)299 TfLiteStatus SaveDelegatedNodes(TfLiteContext* context,
300 Serialization* serialization,
301 const std::string& delegate_id,
302 const TfLiteIntArray* node_ids) {
303 if (!node_ids) return kTfLiteError;
304 std::string cache_key = delegate_id + kDelegatedNodesSuffix;
305 auto entry = serialization->GetEntryForDelegate(cache_key, context);
306 return entry.SetData(context, reinterpret_cast<const char*>(node_ids),
307 (1 + node_ids->size) * sizeof(int));
308 }
309
GetDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,TfLiteIntArray ** node_ids)310 TfLiteStatus GetDelegatedNodes(TfLiteContext* context,
311 Serialization* serialization,
312 const std::string& delegate_id,
313 TfLiteIntArray** node_ids) {
314 if (!node_ids) return kTfLiteError;
315 std::string cache_key = delegate_id + kDelegatedNodesSuffix;
316 auto entry = serialization->GetEntryForDelegate(cache_key, context);
317
318 std::string read_buffer;
319 TF_LITE_ENSURE_STATUS(entry.GetData(context, &read_buffer));
320 if (read_buffer.empty()) return kTfLiteOk;
321 *node_ids = TfLiteIntArrayCopy(
322 reinterpret_cast<const TfLiteIntArray*>(read_buffer.data()));
323 return kTfLiteOk;
324 }
325
326 } // namespace delegates
327 } // namespace tflite
328