/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "compilation.h"

#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>
#include <sys/types.h>
#include <fstream>
#include <climits>

#include "common/utils.h"
#include "common/scoped_trace.h"
#include "validation.h"
#include "device_manager.h"

namespace OHOS {
namespace NeuralNetworkRuntime {
constexpr int MAX_MODEL_SIZE = 200 * 1024 * 1024; // 200MB
constexpr int OCT_UNIT = 8;
constexpr int NULL_PTR_LENGTH = 0;
constexpr int NUMBER_CACHE_INFO_MEMBERS = 3;

// The CRC16 table is generated from the polynomial G(x) = x^16 + x^12 + x^5 + 1 (CRC-16/CCITT)
// with a CRC register initialization value of 0 (0x0000).
static const unsigned short CRC16_TAB[256] = {
    0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7,
    0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef,
    0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6,
    0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de,
    0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485,
    0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d,
    0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4,
    0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc,
    0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823,
    0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b,
    0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12,
    0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a,
    0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41,
    0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49,
    0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70,
    0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78,
    0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f,
    0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067,
    0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e,
    0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256,
    0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d,
    0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405,
    0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c,
    0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634,
    0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab,
    0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3,
    0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a,
    0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92,
    0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9,
    0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1,
    0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8,
    0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0
};

Compilation::Compilation(const InnerModel* innerModel)
    : m_liteGraph(innerModel->GetLiteGraphs()),
    m_inputTensors(innerModel->GetInputTensors()),
    m_outputTensors(innerModel->GetOutputTensors()) {}

OH_NN_ReturnCode Compilation::SetDevice(size_t deviceId)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot set deviceId after the compilation is finished.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    auto& deviceManager = DeviceManager::GetInstance();
    std::shared_ptr<Device> availableDevice = deviceManager.GetDevice(deviceId);
    if (availableDevice == nullptr) {
        LOGE("[Compilation] DeviceId does not exist, deviceId=%zu", deviceId);
        return OH_NN_INVALID_PARAMETER;
    }

    std::vector<bool> supportedList;
    OH_NN_ReturnCode ret = availableDevice->GetSupportedOperation(m_liteGraph, supportedList);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] SetDevice failed, error happened when getting supported operations.");
        return ret;
    }

    for (bool isSupport : supportedList) {
        if (!isSupport) {
            LOGE("[Compilation] SetDevice failed, the device does not support the model, device id: %zu.", deviceId);
            return OH_NN_FAILED;
        }
    }

    bool supportDynamic {false};
    ret = availableDevice->IsDynamicInputSupported(supportDynamic);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] SetDevice failed, error happened when checking whether the device supports dynamic input.");
        return ret;
    }

    if (IsDynamicShape() && (!supportDynamic)) {
        LOGE("[Compilation] SetDevice failed. "
             "The device does not support dynamic shape inputs, but the model has dynamic inputs.");
        return OH_NN_FAILED;
    }

    m_device = availableDevice;
    m_deviceId = deviceId;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetCacheDir(const std::string& cacheModelPath, uint32_t version)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot set cache after the compilation is finished.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("[Compilation] m_device is nullptr, please call SetDevice before calling SetCacheDir.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedCache {false};
    OH_NN_ReturnCode ret = m_device->IsModelCacheSupported(isSupportedCache);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Fail to query whether the device supports saving a cache model.");
        return ret;
    }

    if (!isSupportedCache) {
        LOGE("[Compilation] The device cannot save a cache model.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    char realPathRes[PATH_MAX];
    const char* filePath = realpath(cacheModelPath.c_str(), realPathRes);
    if (filePath == nullptr) {
        LOGE("[Compilation] The cache model path is invalid.");
        return OH_NN_INVALID_PARAMETER;
    }

    struct stat fileInfo;
    if (stat(filePath, &fileInfo) != 0) {
        LOGE("[Compilation] The cache directory does not exist or cannot be accessed.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (!S_ISDIR(fileInfo.st_mode)) {
        LOGE("[Compilation] The cache model path is not a directory.");
        return OH_NN_INVALID_PARAMETER;
    }

    m_cachePath = std::string(filePath) + "/";
    m_version = version;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetPerformance(OH_NN_PerformanceMode performance)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot set performance after the compilation is finished.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("[Compilation] Cannot set performance before setting a device, please call SetDevice first.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedPerformance {false};
    OH_NN_ReturnCode ret = m_device->IsPerformanceModeSupported(isSupportedPerformance);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Call device %zu failed.", m_deviceId);
        return ret;
    }

    if (!isSupportedPerformance) {
        LOGE("[Compilation] Device %zu does not support the performance setting.", m_deviceId);
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (!Validation::ValidatePerformanceMode(performance)) {
        LOGE("[Compilation] SetPerformance was passed an invalid performance=%d", performance);
        return OH_NN_INVALID_PARAMETER;
    }

    m_performance = performance;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetPriority(OH_NN_Priority priority)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot set priority after the compilation is finished.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("[Compilation] Cannot set priority before setting a device, please call SetDevice first.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedPriority {false};
    OH_NN_ReturnCode ret = m_device->IsPrioritySupported(isSupportedPriority);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Call device %zu failed.", m_deviceId);
        return ret;
    }

    if (!isSupportedPriority) {
        LOGE("[Compilation] Device %zu does not support the priority setting.", m_deviceId);
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (!Validation::ValidatePriority(priority)) {
        LOGE("[Compilation] SetPriority was passed an invalid priority=%d", priority);
        return OH_NN_INVALID_PARAMETER;
    }

    m_priority = priority;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetEnableFp16(bool isFp16)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot enable float16 after the compilation is finished.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("[Compilation] Cannot enable fp16 before setting a device, please call SetDevice first.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedFp16 {false};
    OH_NN_ReturnCode ret = m_device->IsFloat16PrecisionSupported(isSupportedFp16);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Call device %zu failed.", m_deviceId);
        return ret;
    }

    if (!isSupportedFp16) {
        LOGE("[Compilation] Device %zu does not support the float16 precision setting.", m_deviceId);
        return OH_NN_OPERATION_FORBIDDEN;
    }

    m_enableFp16 = isFp16;
    return OH_NN_SUCCESS;
}

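// Compute a table-driven CRC-16 over `length` bytes of `buffer`, one byte per iteration,
// using CRC16_TAB above (polynomial 0x1021, initial value 0x0000).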
unsigned short Compilation::GetCrc16(const unsigned char* buffer, size_t length) const
{
    unsigned short crc16 = 0;
    for (size_t i = 0; i < length; ++i) {
        uint8_t tableIndex = ((crc16 >> OCT_UNIT) ^ *buffer++) & 0x00ff;
        crc16 = (crc16 << OCT_UNIT) ^ CRC16_TAB[tableIndex];
    }
    return crc16;
}

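// Write the prepared cache-info header (cacheSize bytes of cacheInfo) to
// <m_cachePath>/cache_info.nncache.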
OH_NN_ReturnCode Compilation::GenerateCacheInfo(uint32_t cacheSize, std::unique_ptr<uint64_t[]>& cacheInfo) const
{
    std::string cacheInfoPath = m_cachePath + "cache_info.nncache";
    std::ofstream cacheInfoStream(cacheInfoPath, std::ios::binary | std::ios::out | std::ios::trunc);
    if (cacheInfoStream.fail()) {
        LOGE("[Compilation] Model cache info file is invalid.");
        return OH_NN_INVALID_FILE;
    }

    if (!cacheInfoStream.write(reinterpret_cast<const char*>(cacheInfo.get()), cacheSize)) {
        LOGE("[Compilation] Fail to write cache info.");
        cacheInfoStream.close();
        return OH_NN_FAILED;
    }

    cacheInfoStream.close();
    return OH_NN_SUCCESS;
}

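// Fill cacheInfo with {fileNumber, version, deviceId, checksums...} and write each model
// buffer to <m_cachePath>/<i>.nncache, recording its CRC-16 checksum in cacheInfo.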
OH_NN_ReturnCode Compilation::GenerateCacheModel(size_t cacheNumber, std::unique_ptr<uint64_t[]>& cacheInfo,
    std::vector<ModelBuffer> modelBuffer) const
{
    auto cacheInfoPtr = cacheInfo.get();
    *cacheInfoPtr++ = static_cast<uint64_t>(cacheNumber);
    *cacheInfoPtr++ = static_cast<uint64_t>(m_version);
    *cacheInfoPtr++ = static_cast<uint64_t>(m_deviceId);
    for (uint32_t i = 0; i < cacheNumber; ++i) {
        std::string cacheModelFile = m_cachePath + std::to_string(i) + ".nncache";
        std::ofstream cacheModelStream(cacheModelFile, std::ios::binary | std::ios::out | std::ios::trunc);
        if (cacheModelStream.fail()) {
            LOGE("[Compilation] Model cache file is invalid.");
            return OH_NN_INVALID_FILE;
        }

        uint64_t checkSum = static_cast<uint64_t>(GetCrc16(static_cast<const unsigned char*>(modelBuffer[i].buffer),
            modelBuffer[i].length));
        *cacheInfoPtr++ = checkSum;
        if (!cacheModelStream.write(static_cast<const char*>(modelBuffer[i].buffer), modelBuffer[i].length)) {
            LOGE("[Compilation] Fail to write cache model.");
            cacheModelStream.close();
            return OH_NN_FAILED;
        }

        cacheModelStream.close();
    }

    return OH_NN_SUCCESS;
}

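// Serialize the exported model buffers and their metadata into the cache directory:
// one <i>.nncache file per buffer plus a cache_info.nncache header.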
OH_NN_ReturnCode Compilation::GenerateCacheFiles(const std::vector<ModelBuffer>& modelBuffer) const
{
    const size_t cacheNumber = modelBuffer.size();
    uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber;
    std::unique_ptr<uint64_t[]> cacheInfo = std::make_unique<uint64_t[]>(cacheSize);
    if (cacheInfo == nullptr) {
        LOGE("[Compilation] Fail to create a cacheInfo instance.");
        return OH_NN_MEMORY_ERROR;
    }

    OH_NN_ReturnCode ret = GenerateCacheModel(cacheNumber, cacheInfo, modelBuffer);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    uint32_t infoCharNumber = cacheSize * sizeof(uint64_t);
    ret = GenerateCacheInfo(infoCharNumber, cacheInfo);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    return OH_NN_SUCCESS;
}

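// Determine the length of an open cache file by seeking to its end; reject empty files
// and files larger than MAX_MODEL_SIZE.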
OH_NN_ReturnCode Compilation::GetCacheFileLength(std::ifstream& ifs, int& fsize) const
{
    ifs.seekg(0, std::ios::end);
    if (!ifs.good()) {
        LOGE("[Compilation] Fail to set the position of the next character to be extracted from the input stream.");
        return OH_NN_INVALID_FILE;
    }

    int handleValue = ifs.tellg();
    if (handleValue == -1) {
        LOGE("[Compilation] Unable to get the position of the input stream.");
        return OH_NN_INVALID_FILE;
    }

    if ((handleValue > MAX_MODEL_SIZE) || (handleValue == NULL_PTR_LENGTH)) {
        LOGE("[Compilation] Unable to read a huge or empty input stream, cache file size=%d", handleValue);
        return OH_NN_INVALID_FILE;
    }

    fsize = handleValue;
    return OH_NN_SUCCESS;
}

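// Read a whole cache model file into a buffer allocated by the device; on success the
// caller owns modelBuffer.buffer and must release it via m_device->ReleaseBuffer().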
OH_NN_ReturnCode Compilation::ReadCacheModelFile(const std::string& file, ModelBuffer& modelBuffer) const
{
    std::ifstream ifs(file.c_str(), std::ios::in | std::ios::binary);
    if (!ifs) {
        LOGE("[Compilation] Fail to open the cache file.");
        return OH_NN_INVALID_FILE;
    }

    int fsize {-1};
    OH_NN_ReturnCode ret = GetCacheFileLength(ifs, fsize);
    if (ret != OH_NN_SUCCESS) {
        ifs.close();
        return ret;
    }

    ifs.seekg(0, std::ios::beg);
    if (!ifs.good()) {
        LOGE("[Compilation] Fail to set the position of the next character to be extracted "
            "from the cache model stream.");
        ifs.close();
        return OH_NN_FAILED;
    }

    char* ptr = static_cast<char*>(m_device->AllocateBuffer(fsize));
    if (ptr == nullptr) {
        LOGE("[Compilation] Fail to create the file buffer.");
        ifs.close();
        return OH_NN_NULL_PTR;
    }

    ifs.read(ptr, fsize);
    if (!ifs.good()) {
        LOGE("[Compilation] Fail to read the characters from the cache model stream.");
        ifs.close();
        m_device->ReleaseBuffer(ptr);
        ptr = nullptr;
        return OH_NN_FAILED;
    }

    ifs.close();
    modelBuffer.buffer = ptr;
    modelBuffer.length = fsize;
    return OH_NN_SUCCESS;
}

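// Parse cache_info.nncache into modelCacheInfo and validate that the cached deviceId
// matches the current device; also load the per-file CRC-16 checksums.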
OH_NN_ReturnCode Compilation::CheckCacheInfo(ModelCacheInfo& modelCacheInfo, const std::string& cacheInfoPath) const
{
    std::ifstream infoCacheFile(cacheInfoPath.c_str(), std::ios::in | std::ios::binary);
    if (!infoCacheFile) {
        LOGE("[Compilation] Opening the cache info file failed.");
        return OH_NN_INVALID_FILE;
    }

    int charNumber = NUMBER_CACHE_INFO_MEMBERS * sizeof(uint64_t);
    if (!infoCacheFile.read(reinterpret_cast<char*>(&modelCacheInfo), charNumber)) {
        LOGE("[Compilation] Fail to read the content of the cache info file.");
        infoCacheFile.close();
        return OH_NN_INVALID_FILE;
    }

    // modelCacheInfo.deviceId is an int64_t converted from a size_t value,
    // so the conversion back to size_t here does not truncate the value.
    size_t deviceId = static_cast<size_t>(modelCacheInfo.deviceId);
    if (deviceId != m_deviceId) {
        LOGE("[Compilation] The deviceId=%zu in the cache files differs from the current deviceId=%zu, "
            "please change the cache directory or the current deviceId.", deviceId, m_deviceId);
        infoCacheFile.close();
        return OH_NN_INVALID_PARAMETER;
    }

    std::vector<uint64_t> modelCheckSum;
    modelCheckSum.resize(modelCacheInfo.fileNumber);
    modelCacheInfo.modelCheckSum.resize(modelCacheInfo.fileNumber);
    if (!infoCacheFile.read(reinterpret_cast<char*>(&modelCheckSum[0]),
        modelCacheInfo.fileNumber * sizeof(uint64_t))) {
        LOGE("[Compilation] The cache info file has been changed.");
        infoCacheFile.close();
        return OH_NN_INVALID_FILE;
    }

    for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
        modelCacheInfo.modelCheckSum[i] = static_cast<unsigned short>(modelCheckSum[i]);
    }

    return OH_NN_SUCCESS;
}

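// Delete cache_info.nncache and every <i>.nncache file below fileNumber; missing model
// files are skipped with a warning.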
OH_NN_ReturnCode Compilation::RemoveCacheFiles(uint32_t fileNumber) const
{
    std::string cacheInfoPath = m_cachePath + "cache_info.nncache";
    if (remove(cacheInfoPath.c_str()) == -1) {
        LOGE("[Compilation] Fail to remove the file %s, please delete the file manually.", cacheInfoPath.c_str());
        return OH_NN_FAILED;
    }
    LOGI("[Compilation] Succeeded in removing the file cache_info.nncache.");

    for (uint32_t i = 0; i < fileNumber; ++i) {
        std::string fileName = std::to_string(i) + ".nncache";
        std::string cacheModelPath = m_cachePath + fileName;
        if (access(cacheModelPath.c_str(), F_OK) != 0) {
            LOGW("[Compilation] The file %s does not exist, no need to delete it.", cacheModelPath.c_str());
            continue;
        }

        if (remove(cacheModelPath.c_str()) == -1) {
            LOGE("[Compilation] Fail to remove the file %s, please delete the file manually.", cacheModelPath.c_str());
            return OH_NN_FAILED;
        }
        LOGI("[Compilation] Succeeded in removing the file %s", cacheModelPath.c_str());
    }
    return OH_NN_SUCCESS;
}

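// Read every cached model file and verify its CRC-16 checksum against the values stored
// in cache_info.nncache, collecting the verified buffers in modelBuffers.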
OH_NN_ReturnCode Compilation::CheckCacheModel(const ModelCacheInfo& modelCacheInfo,
    std::vector<ModelBuffer>& modelBuffers) const
{
    for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
        std::string cacheModelPath = m_cachePath + std::to_string(i) + ".nncache";
        if (access(cacheModelPath.c_str(), F_OK) != 0) {
            LOGE("[Compilation] The cache model file %s does not exist.", cacheModelPath.c_str());
            return OH_NN_INVALID_FILE;
        }

        ModelBuffer modelBuffer;
        OH_NN_ReturnCode ret = ReadCacheModelFile(cacheModelPath, modelBuffer);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Reading the cache model file failed.");
            return ret;
        }

        if (GetCrc16(static_cast<const unsigned char*>(modelBuffer.buffer),
            modelBuffer.length) != modelCacheInfo.modelCheckSum[i]) {
            LOGE("[Compilation] The cache model file %s has been changed.", cacheModelPath.c_str());
            return OH_NN_INVALID_FILE;
        }

        modelBuffers.emplace_back(std::move(modelBuffer));
    }

    return OH_NN_SUCCESS;
}

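// Compile the model on the device from the lite graph and wrap the prepared model in an
// ExecutionPlan; this is the base step for both cache-less and cache-generating builds.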
OH_NN_ReturnCode Compilation::NormalBuild(std::shared_ptr<PreparedModel>& preparedModel)
{
    ModelConfig config {m_enableFp16, m_performance, m_priority};
    OH_NN_ReturnCode ret = m_device->PrepareModel(m_liteGraph, config, preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing the model failed when building normally.");
        return ret;
    }

    m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
    if (m_executionPlan == nullptr) {
        LOGE("[Compilation] Fail to create an ExecutionPlan instance.");
        return OH_NN_MEMORY_ERROR;
    }

    return OH_NN_SUCCESS;
}

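// Build the model normally, then export the compiled model and write it to the cache
// directory for reuse by later builds.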
OH_NN_ReturnCode Compilation::GenCacheBuild(std::shared_ptr<PreparedModel>& preparedModel)
{
    OH_NN_ReturnCode ret = NormalBuild(preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing the model failed when generating the cache.");
        return ret;
    }

    std::vector<ModelBuffer> modelBuffers;
    ret = preparedModel->ExportModelCache(modelBuffers);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Exporting the model cache failed.");
        return ret;
    }

    ret = GenerateCacheFiles(modelBuffers);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Generating the cache files failed.");
        return ret;
    }

    LOGI("[Compilation] Exported the model cache successfully.");
    return OH_NN_SUCCESS;
}

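// Remove the stale cache files and rebuild the model, regenerating the cache; used when
// the requested version is newer than the cached one.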
OH_NN_ReturnCode Compilation::ReGenCacheBuild(uint32_t fileNumber, std::shared_ptr<PreparedModel>& preparedModel)
{
    OH_NN_ReturnCode ret = RemoveCacheFiles(fileNumber);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    ret = GenCacheBuild(preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Building in cache-generation mode failed.");
        return ret;
    }

    LOGI("[Compilation] Updated the model cache successfully.");
    return OH_NN_SUCCESS;
}

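// Verify the cached model files and prepare the model from them, releasing the device
// buffers if the verification fails.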
OH_NN_ReturnCode Compilation::LoadCacheBuild(std::shared_ptr<PreparedModel>& preparedModel,
    const ModelCacheInfo& cacheInfo)
{
    std::vector<ModelBuffer> modelBuffers;
    OH_NN_ReturnCode ret = CheckCacheModel(cacheInfo, modelBuffers);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Checking the cache model failed.");
        for (size_t i = 0; i < modelBuffers.size(); ++i) {
            m_device->ReleaseBuffer(modelBuffers[i].buffer);
            modelBuffers[i].buffer = nullptr;
            modelBuffers[i].length = 0;
        }
        return ret;
    }

    ModelConfig config {m_enableFp16, m_performance, m_priority};
    ret = m_device->PrepareModelFromModelCache(modelBuffers, config, preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing the model from the cache failed.");
        return ret;
    }

    LOGI("[Compilation] Loaded the cache successfully.");

    m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
    if (m_executionPlan == nullptr) {
        LOGE("[Compilation] Fail to create an ExecutionPlan instance.");
        return OH_NN_MEMORY_ERROR;
    }

    return OH_NN_SUCCESS;
}

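// Dispatch the build: build normally if no cache directory is set; generate the cache if
// none exists; regenerate it if m_version is newer than the cached version; forbid older
// versions; otherwise load the model from the cache.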
OH_NN_ReturnCode Compilation::InnerBuild()
{
    OH_NN_ReturnCode ret;
    std::shared_ptr<PreparedModel> preparedModel;
    if (m_cachePath.empty()) {
        ret = NormalBuild(preparedModel);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Fail to build the model normally.");
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    std::string cacheInfoPath = m_cachePath + "cache_info.nncache";
    if (access(cacheInfoPath.c_str(), F_OK) != 0) {
        ret = GenCacheBuild(preparedModel);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Fail to build in cache-generation mode.");
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    ModelCacheInfo cacheInfo;
    ret = CheckCacheInfo(cacheInfo, cacheInfoPath);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    if (m_version > cacheInfo.version) {
        ret = ReGenCacheBuild(cacheInfo.fileNumber, preparedModel);
        if (ret != OH_NN_SUCCESS) {
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    if (m_version < cacheInfo.version) {
        LOGE("[Compilation] The current version is lower than that of the cache files, please set a higher version.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    ret = LoadCacheBuild(preparedModel, cacheInfo);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    m_isBuild = true;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::Build()
{
    NNRT_TRACE_NAME("Compilation");
    if (m_isBuild) {
        LOGE("[Compilation] Cannot build again after the compilation is finished.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("[Compilation] m_device is nullptr, please call SetDevice before building the model.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    OH_NN_ReturnCode ret = InnerBuild();
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    return OH_NN_SUCCESS;
}

std::shared_ptr<ExecutionPlan> Compilation::GetExecutionPlan() const
{
    return m_executionPlan;
}

std::vector<std::shared_ptr<NNTensor>> Compilation::GetInputTensors() const
{
    return m_inputTensors;
}

std::vector<std::shared_ptr<NNTensor>> Compilation::GetOutputTensors() const
{
    return m_outputTensors;
}

bool Compilation::IsBuild() const
{
    return m_isBuild;
}

bool Compilation::IsDynamicShape() const
{
    for (size_t i = 0; i < m_inputTensors.size(); ++i) {
        if (m_inputTensors[i]->IsDynamicShape()) {
            return true;
        }
    }
    return false;
}
} // namespace NeuralNetworkRuntime
} // namespace OHOS