• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "nncompiled_cache.h"

#include <unistd.h>
#include <cstring>
#include <functional>
#include <memory>

#include "common/utils.h"
#include "backend_manager.h"
#include "nnbackend.h"
25 
26 namespace OHOS {
27 namespace NeuralNetworkRuntime {
// Largest cache model file accepted when reading it back: 200 MiB.
constexpr int MAX_MODEL_SIZE{200 * 1024 * 1024};
// A cache file whose length equals this value (zero bytes) is rejected as empty.
constexpr int NULL_PTR_LENGTH{0};
// Number of fixed uint64_t header fields in the cache info file (file count, version, backend ID).
constexpr int NUMBER_CACHE_INFO_MEMBERS{3};
// Bit-shift width used when folding the checksum accumulator down to 16 bits (a bit count, not a radix).
constexpr int HEX_UNIT{16};
32 
Save(const std::vector<OHOS::NeuralNetworkRuntime::Buffer> & caches,const std::string & cacheDir,uint32_t version)33 OH_NN_ReturnCode NNCompiledCache::Save(const std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches,
34                                        const std::string& cacheDir,
35                                        uint32_t version)
36 {
37     if (caches.empty()) {
38         LOGE("[NNCompiledCache] Save failed, caches is empty.");
39         return OH_NN_INVALID_PARAMETER;
40     }
41 
42     if (m_device == nullptr) {
43         LOGE("[NNCompiledCache] Save failed, m_device is empty.");
44         return OH_NN_INVALID_PARAMETER;
45     }
46 
47     OH_NN_ReturnCode ret = GenerateCacheFiles(caches, cacheDir, version);
48     if (ret != OH_NN_SUCCESS) {
49         LOGE("[NNCompiledCache] Save failed, error happened when calling GenerateCacheFiles.");
50         return ret;
51     }
52 
53     LOGI("[NNCompiledCache] Save success. %zu caches are saved.", caches.size());
54     return OH_NN_SUCCESS;
55 }
56 
Restore(const std::string & cacheDir,uint32_t version,std::vector<OHOS::NeuralNetworkRuntime::Buffer> & caches)57 OH_NN_ReturnCode NNCompiledCache::Restore(const std::string& cacheDir,
58                                           uint32_t version,
59                                           std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches)
60 {
61     if (cacheDir.empty()) {
62         LOGE("[NNCompiledCache] Restore failed, cacheDir is empty.");
63         return OH_NN_INVALID_PARAMETER;
64     }
65 
66     if (!caches.empty()) {
67         LOGE("[NNCompiledCache] Restore failed, caches is not empty.");
68         return OH_NN_INVALID_PARAMETER;
69     }
70 
71     if (m_device == nullptr) {
72         LOGE("[NNCompiledCache] Restore failed, m_device is empty.");
73         return OH_NN_INVALID_PARAMETER;
74     }
75 
76     std::string cacheInfoPath = cacheDir + "/" + m_modelName + "cache_info.nncache";
77     char path[PATH_MAX];
78     if (realpath(cacheInfoPath.c_str(), path) == nullptr) {
79         LOGE("[NNCompiledCache] Restore failed, fail to get the real path of cacheInfoPath.");
80         return OH_NN_INVALID_PARAMETER;
81     }
82     if (access(cacheInfoPath.c_str(), F_OK) != 0) {
83         LOGE("[NNCompiledCache] Restore failed, cacheInfoPath is not exist.");
84         return OH_NN_INVALID_PARAMETER;
85     }
86 
87     NNCompiledCacheInfo cacheInfo;
88     OH_NN_ReturnCode ret = CheckCacheInfo(cacheInfo, cacheInfoPath);
89     if (ret != OH_NN_SUCCESS) {
90         LOGE("[NNCompiledCache] Restore failed, error happened when calling CheckCacheInfo.");
91         return ret;
92     }
93 
94     if (static_cast<uint64_t>(version) > cacheInfo.version) {
95         LOGE("[NNCompiledCache] Restore failed, version is not match. The current version is %{public}u, "
96              "but the cache files version is %{public}zu.",
97              version,
98              static_cast<size_t>(cacheInfo.version));
99         return OH_NN_INVALID_PARAMETER;
100     }
101 
102     if (static_cast<uint64_t>(version) < cacheInfo.version) {
103         LOGE("[NNCompiledCache] Restore failed, the current version is lower than the cache files, "
104              "please set a higher version.");
105         return OH_NN_OPERATION_FORBIDDEN;
106     }
107 
108     for (uint32_t i = 0; i < cacheInfo.fileNumber; ++i) {
109         std::string cacheModelPath = cacheDir + "/" + m_modelName + std::to_string(i) + ".nncache";
110         if (access(cacheModelPath.c_str(), 0) != 0) {
111             LOGE("[NNCompiledCache] Restore failed, %{public}s is not exist.", cacheModelPath.c_str());
112             return OH_NN_INVALID_PARAMETER;
113         }
114 
115         OHOS::NeuralNetworkRuntime::Buffer modelBuffer;
116         ret = ReadCacheModelFile(cacheModelPath, modelBuffer);
117         if (ret != OH_NN_SUCCESS) {
118             LOGE("[NNCompiledCache] Restore failed, error happened when calling ReadCacheModelFile.");
119             return ret;
120         }
121 
122         if (GetCrc16(static_cast<char*>(modelBuffer.data), modelBuffer.length) !=
123             cacheInfo.modelCheckSum[i]) {
124             LOGE("[NNCompiledCache] Restore failed, the cache model file %{public}s has been changed.",
125                  cacheModelPath.c_str());
126             return OH_NN_INVALID_FILE;
127         }
128 
129         caches.emplace_back(std::move(modelBuffer));
130     }
131 
132     return ret;
133 }
134 
SetBackend(size_t backendID)135 OH_NN_ReturnCode NNCompiledCache::SetBackend(size_t backendID)
136 {
137     const BackendManager& backendManager = BackendManager::GetInstance();
138     std::shared_ptr<Backend> backend = backendManager.GetBackend(backendID);
139     if (backend == nullptr) {
140         LOGE("[NNCompiledCache] SetBackend failed, backend with backendID %{public}zu is not exist.", backendID);
141         return OH_NN_INVALID_PARAMETER;
142     }
143 
144     std::shared_ptr<NNBackend> nnBackend = std::reinterpret_pointer_cast<NNBackend>(backend);
145     m_device = nnBackend->GetDevice();
146     if (m_device == nullptr) {
147         LOGE("[NNCompiledCache] SetBackend failed, device with backendID %{public}zu is not exist.", backendID);
148         return OH_NN_FAILED;
149     }
150 
151     m_backendID = backendID;
152     return OH_NN_SUCCESS;
153 }
154 
SetModelName(const std::string & modelName)155 void NNCompiledCache::SetModelName(const std::string& modelName)
156 {
157     m_modelName = modelName;
158 }
159 
GenerateCacheFiles(const std::vector<OHOS::NeuralNetworkRuntime::Buffer> & caches,const std::string & cacheDir,uint32_t version) const160 OH_NN_ReturnCode NNCompiledCache::GenerateCacheFiles(const std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches,
161                                                      const std::string& cacheDir,
162                                                      uint32_t version) const
163 {
164     const size_t cacheNumber = caches.size();
165     uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber;
166     std::unique_ptr<uint64_t[]> cacheInfo = CreateUniquePtr<uint64_t[]>(cacheSize);
167     if (cacheInfo == nullptr) {
168         LOGE("[NNCompiledCache] GenerateCacheFiles failed, fail to create cacheInfo instance.");
169         return OH_NN_MEMORY_ERROR;
170     }
171 
172     OH_NN_ReturnCode ret = GenerateCacheModel(caches, cacheInfo, cacheDir, version);
173     if (ret != OH_NN_SUCCESS) {
174         LOGE("[NNCompiledCache] GenerateCacheFiles failed, error happened when calling GenerateCacheModel.");
175         return ret;
176     }
177 
178     uint32_t infoCharNumber = cacheSize * sizeof(uint64_t);
179     ret = WriteCacheInfo(infoCharNumber, cacheInfo, cacheDir);
180     if (ret != OH_NN_SUCCESS) {
181         LOGE("[NNCompiledCache] GenerateCacheFiles failed, error happened when calling WriteCacheInfo.");
182         return ret;
183     }
184 
185     return OH_NN_SUCCESS;
186 }
187 
GenerateCacheModel(const std::vector<OHOS::NeuralNetworkRuntime::Buffer> & caches,std::unique_ptr<uint64_t[]> & cacheInfo,const std::string & cacheDir,uint32_t version) const188 OH_NN_ReturnCode NNCompiledCache::GenerateCacheModel(const std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches,
189                                                      std::unique_ptr<uint64_t[]>& cacheInfo,
190                                                      const std::string& cacheDir,
191                                                      uint32_t version) const
192 {
193     size_t cacheNumber = caches.size();
194 
195     auto cacheInfoPtr = cacheInfo.get();
196     *cacheInfoPtr++ = static_cast<uint64_t>(cacheNumber);
197     *cacheInfoPtr++ = static_cast<uint64_t>(version);
198     *cacheInfoPtr++ = static_cast<uint64_t>(m_backendID); // Should call SetBackend first.
199 
200     for (size_t i = 0; i < cacheNumber; ++i) {
201         std::string cacheModelFile = cacheDir + "/" + m_modelName + std::to_string(i) + ".nncache";
202         std::ofstream cacheModelStream(cacheModelFile, std::ios::binary | std::ios::out | std::ios::trunc);
203         if (cacheModelStream.fail()) {
204             LOGE("[NNCompiledCache] GenerateCacheModel failed, model cache file is invalid.");
205             return OH_NN_INVALID_PARAMETER;
206         }
207 
208         uint64_t checkSum =
209             static_cast<uint64_t>(GetCrc16(static_cast<char*>(caches[i].data), caches[i].length));
210         *cacheInfoPtr++ = checkSum;
211         if (!cacheModelStream.write(static_cast<const char*>(caches[i].data), caches[i].length)) {
212             LOGE("[NNCompiledCache] GenerateCacheModel failed, fail to write cache model.");
213             cacheModelStream.close();
214             return OH_NN_SAVE_CACHE_EXCEPTION;
215         };
216 
217         cacheModelStream.close();
218     }
219 
220     return OH_NN_SUCCESS;
221 }
222 
WriteCacheInfo(uint32_t cacheSize,std::unique_ptr<uint64_t[]> & cacheInfo,const std::string & cacheDir) const223 OH_NN_ReturnCode NNCompiledCache::WriteCacheInfo(uint32_t cacheSize,
224                                                  std::unique_ptr<uint64_t[]>& cacheInfo,
225                                                  const std::string& cacheDir) const
226 {
227     std::string cacheInfoPath = cacheDir + "/" + m_modelName + "cache_info.nncache";
228     std::ofstream cacheInfoStream(cacheInfoPath, std::ios::binary | std::ios::out | std::ios::trunc);
229     if (cacheInfoStream.fail()) {
230         LOGE("[NNCompiledCache] WriteCacheInfo failed, model cache info file is invalid.");
231         return OH_NN_INVALID_FILE;
232     }
233 
234     if (!cacheInfoStream.write(reinterpret_cast<const char*>(cacheInfo.get()), cacheSize)) {
235         LOGE("[NNCompiledCache] WriteCacheInfo failed, fail to write cache info.");
236         cacheInfoStream.close();
237         return OH_NN_SAVE_CACHE_EXCEPTION;
238     }
239 
240     cacheInfoStream.close();
241     return OH_NN_SUCCESS;
242 }
243 
CheckCacheInfo(NNCompiledCacheInfo & modelCacheInfo,const std::string & cacheInfoPath) const244 OH_NN_ReturnCode NNCompiledCache::CheckCacheInfo(NNCompiledCacheInfo& modelCacheInfo,
245                                                  const std::string& cacheInfoPath) const
246 {
247     // cacheInfoPath is validated outside.
248     std::ifstream infoCacheFile(cacheInfoPath.c_str(), std::ios::in | std::ios::binary);
249     if (!infoCacheFile) {
250         LOGE("[NNCompiledCache] CheckCacheInfo failed, error happened when opening cache info file.");
251         return OH_NN_INVALID_FILE;
252     }
253 
254     int charNumber = NUMBER_CACHE_INFO_MEMBERS * sizeof(uint64_t);
255     if (!infoCacheFile.read(reinterpret_cast<char*>(&(modelCacheInfo)), charNumber)) {
256         LOGE("[NNCompiledCache] CheckCacheInfo failed, error happened when reading cache info file.");
257         infoCacheFile.close();
258         return OH_NN_INVALID_FILE;
259     }
260 
261     // modelCacheInfo.deviceId type is int64_t,
262     // it is transformed from size_t value, so the transform here will not truncate value.
263     size_t deviceId = static_cast<size_t>(modelCacheInfo.deviceId);
264     if (deviceId != m_backendID) {
265         LOGE("[NNCompiledCache] CheckCacheInfo failed. The deviceId=%{public}zu in the cache files "
266              "is different from current deviceId=%{public}zu,"
267              "please change the cache directory or current deviceId.",
268              deviceId,
269              m_backendID);
270         infoCacheFile.close();
271         return OH_NN_INVALID_PARAMETER;
272     }
273 
274     std::vector<uint64_t> modelCheckSum;
275     modelCheckSum.resize(modelCacheInfo.fileNumber);
276     modelCacheInfo.modelCheckSum.resize(modelCacheInfo.fileNumber);
277     if (!infoCacheFile.read(reinterpret_cast<char*>(&modelCheckSum[0]),
278         modelCacheInfo.fileNumber * sizeof(uint64_t))) {
279         LOGE("[NNCompiledCache] CheckCacheInfo failed. The info cache file has been changed.");
280         infoCacheFile.close();
281         return OH_NN_INVALID_FILE;
282     }
283 
284     for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
285         modelCacheInfo.modelCheckSum[i] = static_cast<unsigned short>(modelCheckSum[i]);
286     }
287 
288     return OH_NN_SUCCESS;
289 }
290 
ReadCacheModelFile(const std::string & filePath,OHOS::NeuralNetworkRuntime::Buffer & cache) const291 OH_NN_ReturnCode NNCompiledCache::ReadCacheModelFile(const std::string& filePath,
292                                                      OHOS::NeuralNetworkRuntime::Buffer& cache) const
293 {
294     // filePath is validate in NNCompiledCache::Restore, no need to check again.
295     std::ifstream ifs(filePath.c_str(), std::ios::in | std::ios::binary);
296     if (!ifs) {
297         LOGE("[NNCompiledCache] ReadCacheModelFile failed, file is invalid.");
298         return OH_NN_INVALID_FILE;
299     }
300 
301     int fsize{-1};
302     OH_NN_ReturnCode ret = GetCacheFileLength(ifs, fsize);
303     if (ret != OH_NN_SUCCESS) {
304         ifs.close();
305         LOGE("[NNCompiledCache] ReadCacheModelFile failed, get file %{public}s length fialed.", filePath.c_str());
306         return ret;
307     }
308 
309     ifs.seekg(0, std::ios::beg);
310     if (!ifs.good()) {
311         LOGE("[NNCompiledCache] ReadCacheModelFile failed, file is invalid.");
312         ifs.close();
313         return OH_NN_INVALID_FILE;
314     }
315 
316     char* ptr = static_cast<char*>(m_device->AllocateBuffer(fsize));
317     if (ptr == nullptr) {
318         LOGE("[NNCompiledCache] ReadCacheModelFile failed, failed to allocate memory.");
319         ifs.close();
320         return OH_NN_MEMORY_ERROR;
321     }
322 
323     ifs.read(ptr, fsize);
324     if (!ifs.good()) {
325         LOGE("[NNCompiledCache] ReadCacheModelFile failed, failed to read file.");
326         ifs.close();
327         m_device->ReleaseBuffer(ptr);
328         ptr = nullptr;
329         return OH_NN_INVALID_FILE;
330     }
331 
332     ifs.close();
333     cache.data = ptr;
334     cache.length = static_cast<size_t>(fsize); // fsize should be non-negative, safe to cast.
335     return OH_NN_SUCCESS;
336 }
337 
GetCrc16(char * buffer,size_t length) const338 unsigned short NNCompiledCache::GetCrc16(char* buffer, size_t length) const
339 {
340     unsigned int sum = 0;
341     while (length > 1) {
342         sum += *(reinterpret_cast<unsigned short*>(buffer));
343         length -= sizeof(unsigned short);
344         buffer += sizeof(unsigned short);
345     }
346 
347     if (length > 0) {
348         sum += *(reinterpret_cast<unsigned char*>(buffer));
349     }
350 
351     while (sum >> HEX_UNIT) {
352         sum = (sum >> HEX_UNIT) + (sum & 0xffff);
353     }
354 
355     return static_cast<unsigned short>(~sum);
356 }
357 
GetCacheFileLength(std::ifstream & ifs,int & fileSize) const358 OH_NN_ReturnCode NNCompiledCache::GetCacheFileLength(std::ifstream& ifs, int& fileSize) const
359 {
360     ifs.seekg(0, std::ios::end);
361     if (!ifs.good()) {
362         LOGE("[NNCompiledCache] GetCacheFileLength failed, fail to set the position of the next character "
363              "to be extracted from the input stream.");
364         return OH_NN_FAILED;
365     }
366 
367     int handleValue = ifs.tellg();
368     if (handleValue == -1) {
369         LOGE("[NNCompiledCache] GetCacheFileLength failed, fail to get position of the input stream.");
370         return OH_NN_INVALID_FILE;
371     }
372 
373     if ((handleValue > MAX_MODEL_SIZE) || (handleValue == NULL_PTR_LENGTH)) {
374         LOGE("[NNCompiledCache] GetCacheFileLength failed, unable to read huge or empty input stream, "
375              "get cache file size=%{public}d",
376              handleValue);
377         return OH_NN_INVALID_FILE;
378     }
379 
380     fileSize = handleValue;
381     return OH_NN_SUCCESS;
382 }
383 } // namespace NeuralNetworkRuntime
384 } // namespace OHOS
385