/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "compilation.h"

#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>
#include <sys/types.h>
#include <fstream>
#include <climits>

#include "common/utils.h"
#include "common/scoped_trace.h"
#include "validation.h"
#include "device_manager.h"

namespace OHOS {
namespace NeuralNetworkRuntime {
constexpr int MAX_MODEL_SIZE = 200 * 1024 * 1024; // 200MB
constexpr int OCT_UNIT = 8;
constexpr int NULL_PTR_LENGTH = 0;
constexpr int NUMBER_CACHE_INFO_MEMBERS = 3;

// The CRC16 table is generated from the CRC-16/CCITT polynomial
// G(x) = x^16 + x^12 + x^5 + 1 (0x1021), with the CRC register initialized to "0" (0x0000).
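// Each entry CRC16_TAB[i] is the CRC remainder of the single byte value i,
// which lets GetCrc16 below fold one input byte into the checksum per lookup.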
static const unsigned short CRC16_TAB[256] = {
    0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7,
    0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef,
    0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6,
    0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de,
    0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485,
    0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d,
    0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4,
    0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc,
    0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823,
    0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b,
    0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12,
    0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a,
    0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41,
    0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49,
    0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70,
    0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78,
    0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f,
    0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067,
    0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e,
    0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256,
    0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d,
    0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405,
    0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c,
    0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634,
    0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab,
    0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3,
    0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a,
    0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92,
    0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9,
    0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1,
    0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8,
    0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0
};

Compilation::Compilation(const InnerModel* innerModel)
    : m_liteGraph(innerModel->GetLiteGraphs()),
    m_inputTensors(innerModel->GetInputTensors()),
    m_outputTensors(innerModel->GetOutputTensors()),
    m_metaGraph(innerModel->GetMetaGraph()),
    m_quantBuffer(innerModel->GetQuantBuffer()),
    m_modelName(innerModel->GetModelName()) {}

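// Typical call order (a sketch for orientation; identifiers outside this file
// are assumptions, not part of this API):
//     Compilation compilation(&innerModel);
//     compilation.SetDevice(deviceId);            // required before the setters below
//     compilation.SetCacheDir("/data/cache", 1);  // optional: enable the model cache
//     compilation.Build();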
OH_NN_ReturnCode Compilation::SetDevice(size_t deviceId)
{
    if (m_isBuild) {
        LOGE("Cannot set deviceId after compilation finishes.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    auto& deviceManager = DeviceManager::GetInstance();
    std::shared_ptr<Device> availableDevice = deviceManager.GetDevice(deviceId);
    if (availableDevice == nullptr) {
        LOGE("[Compilation] DeviceId does not exist, deviceId=%zu", deviceId);
        return OH_NN_INVALID_PARAMETER;
    }

    std::vector<bool> supportedList;
    OH_NN_ReturnCode ret = availableDevice->GetSupportedOperation(m_liteGraph, supportedList);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] SetDevice failed, error happened when getting supported operations.");
        return ret;
    }

    for (bool isSupport : supportedList) {
        if (!isSupport) {
            LOGE("[Compilation] SetDevice failed, the device does not support the model, device id: %zu.", deviceId);
            return OH_NN_FAILED;
        }
    }

    bool supportDynamic {false};
    ret = availableDevice->IsDynamicInputSupported(supportDynamic);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] SetDevice failed, error happened when checking whether device supports dynamic input.");
        return ret;
    }

    if (IsDynamicShape() && (!supportDynamic)) {
        LOGE("[Compilation] SetDevice failed. "
             "The device does not support dynamic shape inputs, but the model has dynamic inputs.");
        return OH_NN_FAILED;
    }

    m_device = availableDevice;
    m_deviceId = deviceId;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetCacheDir(const std::string& cacheModelPath, uint32_t version)
{
    if (m_isBuild) {
        LOGE("Cannot set cache after compilation finishes.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("m_device is nullptr, please call SetDevice before calling SetCacheDir.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedCache {false};
    OH_NN_ReturnCode ret = m_device->IsModelCacheSupported(isSupportedCache);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Failed to query whether the device supports model caching.");
        return ret;
    }

    if (!isSupportedCache) {
        LOGE("[Compilation] The device does not support saving a cache model.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

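    // realpath() fails unless the path already exists, so the cache directory
    // must be created by the caller beforehand; the resolved absolute path is
    // stored with a trailing '/' for later file-name concatenation.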
    char realPathRes[PATH_MAX];
    const char* filePath = realpath(cacheModelPath.c_str(), realPathRes);
    if (filePath == nullptr) {
        LOGE("[Compilation] The cache model path is invalid.");
        return OH_NN_INVALID_PARAMETER;
    }

    struct stat fileInfo;
    if (stat(filePath, &fileInfo) != 0) {
        LOGE("[Compilation] The cache directory does not exist or cannot be accessed.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (!S_ISDIR(fileInfo.st_mode)) {
        LOGE("[Compilation] The cache model path is not a directory.");
        return OH_NN_INVALID_PARAMETER;
    }

    m_cachePath = std::string(filePath) + "/";
    m_version = version;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetPerformance(OH_NN_PerformanceMode performance)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot set performance after compilation finishes.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("Cannot set performance before setting the device, please call SetDevice first.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedPerformance {false};
    OH_NN_ReturnCode ret = m_device->IsPerformanceModeSupported(isSupportedPerformance);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Call to device %zu failed.", m_deviceId);
        return ret;
    }

    if (!isSupportedPerformance) {
        LOGE("[Compilation] Device %zu does not support the performance setting.", m_deviceId);
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (!Validation::ValidatePerformanceMode(performance)) {
        LOGE("[Compilation] SetPerformance passed invalid performance=%d", performance);
        return OH_NN_INVALID_PARAMETER;
    }

    m_performance = performance;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetPriority(OH_NN_Priority priority)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot set priority after compilation finishes.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("Cannot set priority before setting the device, please call SetDevice first.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedPriority {false};
    OH_NN_ReturnCode ret = m_device->IsPrioritySupported(isSupportedPriority);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Call to device %zu failed.", m_deviceId);
        return ret;
    }

    if (!isSupportedPriority) {
        LOGE("[Compilation] Device %zu does not support the priority setting.", m_deviceId);
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (!Validation::ValidatePriority(priority)) {
        LOGE("[Compilation] SetPriority passed invalid priority=%d", priority);
        return OH_NN_INVALID_PARAMETER;
    }

    m_priority = priority;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::SetEnableFp16(bool isFp16)
{
    if (m_isBuild) {
        LOGE("[Compilation] Cannot enable float16 after compilation finishes.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("Cannot enable fp16 before setting the device, please call SetDevice first.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    bool isSupportedFp16 {false};
    OH_NN_ReturnCode ret = m_device->IsFloat16PrecisionSupported(isSupportedFp16);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Call to device %zu failed.", m_deviceId);
        return ret;
    }

    if (!isSupportedFp16) {
        LOGE("[Compilation] Device %zu does not support the float16 precision setting.", m_deviceId);
        return OH_NN_OPERATION_FORBIDDEN;
    }

    m_enableFp16 = isFp16;
    return OH_NN_SUCCESS;
}

unsigned short Compilation::GetCrc16(const unsigned char* buffer, size_t length) const
{
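    // Table-driven update: XOR the next byte into the high byte of the running
    // CRC, then combine the shifted register with the precomputed remainder.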
    unsigned short crc16 = 0;
    for (size_t i = 0; i < length; ++i) {
        uint8_t tableIndex = ((crc16 >> OCT_UNIT) ^ *buffer++) & 0x00ff;
        crc16 = (crc16 << OCT_UNIT) ^ CRC16_TAB[tableIndex];
    }
    return crc16;
}

OH_NN_ReturnCode Compilation::GenerateCacheInfo(uint32_t cacheSize, std::unique_ptr<uint64_t[]>& cacheInfo) const
{
    std::string cacheInfoPath = m_cachePath + m_modelName + "cache_info.nncache";
    std::ofstream cacheInfoStream(cacheInfoPath, std::ios::binary | std::ios::out | std::ios::trunc);
    if (cacheInfoStream.fail()) {
        LOGE("[Compilation] Model cache info file is invalid.");
        return OH_NN_INVALID_FILE;
    }

    if (!cacheInfoStream.write(reinterpret_cast<const char*>(cacheInfo.get()), cacheSize)) {
        LOGE("[Compilation] Failed to write cache info.");
        cacheInfoStream.close();
        return OH_NN_FAILED;
    }

    cacheInfoStream.close();
    return OH_NN_SUCCESS;
}

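// The cache info file holds (NUMBER_CACHE_INFO_MEMBERS + cacheNumber) uint64_t
// values laid out as [fileNumber, version, deviceId, checkSum_0, ..., checkSum_n-1],
// where each checkSum_i is the CRC16 of the corresponding model cache file.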
OH_NN_ReturnCode Compilation::GenerateCacheModel(size_t cacheNumber, std::unique_ptr<uint64_t[]>& cacheInfo,
    std::vector<Buffer> modelBuffer) const
{
    auto cacheInfoPtr = cacheInfo.get();
    *cacheInfoPtr++ = static_cast<uint64_t>(cacheNumber);
    *cacheInfoPtr++ = static_cast<uint64_t>(m_version);
    *cacheInfoPtr++ = static_cast<uint64_t>(m_deviceId);
    for (uint32_t i = 0; i < cacheNumber; ++i) {
        std::string cacheModelFile = m_cachePath + m_modelName + std::to_string(i) + ".nncache";
        std::ofstream cacheModelStream(cacheModelFile, std::ios::binary | std::ios::out | std::ios::trunc);
        if (cacheModelStream.fail()) {
            LOGE("[Compilation] Model cache file is invalid.");
            return OH_NN_INVALID_FILE;
        }

        uint64_t checkSum = static_cast<uint64_t>(GetCrc16(static_cast<const unsigned char*>(modelBuffer[i].data),
            modelBuffer[i].length));
        *cacheInfoPtr++ = checkSum;
        if (!cacheModelStream.write(static_cast<const char*>(modelBuffer[i].data), modelBuffer[i].length)) {
            LOGE("[Compilation] Failed to write cache model.");
            cacheModelStream.close();
            return OH_NN_FAILED;
        }

        cacheModelStream.close();
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::GenerateCacheFiles(const std::vector<Buffer>& modelBuffer) const
{
    const size_t cacheNumber = modelBuffer.size();
    uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber;
    std::unique_ptr<uint64_t[]> cacheInfo = std::make_unique<uint64_t[]>(cacheSize);
    if (cacheInfo == nullptr) {
        LOGE("Failed to create cacheInfo instance.");
        return OH_NN_MEMORY_ERROR;
    }

    OH_NN_ReturnCode ret = GenerateCacheModel(cacheNumber, cacheInfo, modelBuffer);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    uint32_t infoCharNumber = cacheSize * sizeof(uint64_t);
    ret = GenerateCacheInfo(infoCharNumber, cacheInfo);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::GetCacheFileLength(std::ifstream& ifs, int& fsize) const
{
    ifs.seekg(0, std::ios::end);
    if (!ifs.good()) {
        LOGE("[Compilation] Failed to set the position of the next character to be extracted from the input stream.");
        return OH_NN_INVALID_FILE;
    }

    int handleValue = ifs.tellg();
    if (handleValue == -1) {
        LOGE("[Compilation] Unable to get the position of the input stream.");
        return OH_NN_INVALID_FILE;
    }

    if ((handleValue > MAX_MODEL_SIZE) || (handleValue == NULL_PTR_LENGTH)) {
        LOGE("[Compilation] Unable to read a huge or empty input stream, cache file size=%d", handleValue);
        return OH_NN_INVALID_FILE;
    }

    fsize = handleValue;
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::ReadCacheModelFile(const std::string& file, Buffer& modelBuffer) const
{
    // file is validated outside.
    std::ifstream ifs(file.c_str(), std::ios::in | std::ios::binary);
    if (!ifs) {
        LOGE("[Compilation] Failed to open cache file.");
        return OH_NN_INVALID_FILE;
    }

    int fsize {-1};
    OH_NN_ReturnCode ret = GetCacheFileLength(ifs, fsize);
    if (ret != OH_NN_SUCCESS) {
        ifs.close();
        return ret;
    }

    ifs.seekg(0, std::ios::beg);
    if (!ifs.good()) {
        LOGE("[Compilation] Failed to set the position of the next character to be extracted "
            "from the cache model stream.");
        ifs.close();
        return OH_NN_FAILED;
    }

    char* ptr = static_cast<char*>(m_device->AllocateBuffer(fsize));
    if (ptr == nullptr) {
        LOGE("[Compilation] Failed to create the file buffer.");
        ifs.close();
        return OH_NN_NULL_PTR;
    }

    ifs.read(ptr, fsize);
    if (!ifs.good()) {
        LOGE("[Compilation] Failed to read the characters from the cache model stream.");
        ifs.close();
        m_device->ReleaseBuffer(ptr);
        ptr = nullptr;
        return OH_NN_FAILED;
    }

    ifs.close();
    modelBuffer.data = ptr;
    modelBuffer.length = static_cast<size_t>(fsize); // fsize is non-negative here, so the cast is safe.
    return OH_NN_SUCCESS;
}

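// CheckCacheInfo reads the fixed three-field header first, then one uint64_t
// checksum per cached file, and verifies the recorded deviceId against m_deviceId.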
OH_NN_ReturnCode Compilation::CheckCacheInfo(ModelCacheInfo& modelCacheInfo, const std::string& cacheInfoPath) const
{
    // cacheInfoPath is validated outside.
    std::ifstream infoCacheFile(cacheInfoPath.c_str(), std::ios::in | std::ios::binary);
    if (!infoCacheFile) {
        LOGE("[Compilation] Opening cache info file failed.");
        return OH_NN_INVALID_FILE;
    }

    int charNumber = NUMBER_CACHE_INFO_MEMBERS * sizeof(uint64_t);
    if (!infoCacheFile.read(reinterpret_cast<char*>(&modelCacheInfo), charNumber)) {
        LOGE("[Compilation] Failed to get the content of the info cache file.");
        infoCacheFile.close();
        return OH_NN_INVALID_FILE;
    }

    // modelCacheInfo.deviceId is an int64_t that was transformed from a size_t value,
    // so the conversion here will not truncate the value.
    size_t deviceId = static_cast<size_t>(modelCacheInfo.deviceId);
    if (deviceId != m_deviceId) {
        LOGE("[Compilation] The deviceId=%zu in the cache files is different from the current deviceId=%zu, "
            "please change the cache directory or the current deviceId.", deviceId, m_deviceId);
        infoCacheFile.close();
        return OH_NN_INVALID_PARAMETER;
    }

    std::vector<uint64_t> modelCheckSum;
    modelCheckSum.resize(modelCacheInfo.fileNumber);
    modelCacheInfo.modelCheckSum.resize(modelCacheInfo.fileNumber);
    if (!infoCacheFile.read(reinterpret_cast<char*>(&modelCheckSum[0]),
        modelCacheInfo.fileNumber * sizeof(uint64_t))) {
        LOGE("[Compilation] The info cache file has been changed.");
        infoCacheFile.close();
        return OH_NN_INVALID_FILE;
    }

    for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
        modelCacheInfo.modelCheckSum[i] = static_cast<unsigned short>(modelCheckSum[i]);
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::RemoveCacheFiles(uint32_t fileNumber) const
{
    std::string cacheInfoPath = m_cachePath + m_modelName + "cache_info.nncache";
    if (remove(cacheInfoPath.c_str()) == -1) {
        LOGE("[Compilation] Failed to remove the file %s, please delete the file manually.", cacheInfoPath.c_str());
        return OH_NN_FAILED;
    }
    LOGI("[Compilation] Succeeded in removing the file cache_info.nncache.");

    for (uint32_t i = 0; i < fileNumber; ++i) {
        std::string fileName = m_modelName + std::to_string(i) + ".nncache";
        std::string cacheModelPath = m_cachePath + fileName;
        if (access(cacheModelPath.c_str(), F_OK) != 0) {
            LOGW("[Compilation] The file %s does not exist, no need to delete it.", cacheModelPath.c_str());
            continue;
        }

        if (remove(cacheModelPath.c_str()) == -1) {
            LOGE("[Compilation] Failed to remove the file %s, please delete the file manually.",
                cacheModelPath.c_str());
            return OH_NN_FAILED;
        }
        LOGI("[Compilation] Succeeded in removing the file %s", cacheModelPath.c_str());
    }
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::CheckCacheModel(const ModelCacheInfo& modelCacheInfo,
    std::vector<Buffer>& modelBuffers) const
{
    for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
        std::string cacheModelPath = m_cachePath + m_modelName + std::to_string(i) + ".nncache";
        if (access(cacheModelPath.c_str(), F_OK) != 0) {
            LOGE("[Compilation] The cache model file %s does not exist.", cacheModelPath.c_str());
            return OH_NN_INVALID_FILE;
        }

        Buffer modelBuffer;
        OH_NN_ReturnCode ret = ReadCacheModelFile(cacheModelPath, modelBuffer);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Reading the cache model file failed.");
            return ret;
        }

        if (GetCrc16(static_cast<const unsigned char*>(modelBuffer.data),
            modelBuffer.length) != modelCacheInfo.modelCheckSum[i]) {
            LOGE("[Compilation] The cache model file %s has been changed.", cacheModelPath.c_str());
            return OH_NN_INVALID_FILE;
        }

        modelBuffers.emplace_back(std::move(modelBuffer));
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::NormalBuild(std::shared_ptr<PreparedModel>& preparedModel)
{
    ModelConfig config {m_enableFp16, m_performance, m_priority};
    if ((m_liteGraph == nullptr) && (m_metaGraph == nullptr)) {
        LOGE("[Compilation] Both m_liteGraph and m_metaGraph are nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }

    if ((m_liteGraph != nullptr) && (m_metaGraph != nullptr)) {
        LOGE("[Compilation] Both m_liteGraph and m_metaGraph are non-null, but exactly one is expected.");
        return OH_NN_INVALID_PARAMETER;
    }

    OH_NN_ReturnCode ret {OH_NN_FAILED};
    if (m_liteGraph != nullptr) {
        ret = m_device->PrepareModel(m_liteGraph, config, preparedModel);
    }
    if (m_metaGraph != nullptr) {
        ret = m_device->PrepareModel(m_metaGraph, m_quantBuffer, config, preparedModel);
    }
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing model failed when building normally.");
        return ret;
    }

    m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
    if (m_executionPlan == nullptr) {
        LOGE("[Compilation] Failed to create ExecutionPlan instance.");
        return OH_NN_MEMORY_ERROR;
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::GenCacheBuild(std::shared_ptr<PreparedModel>& preparedModel)
{
    OH_NN_ReturnCode ret = NormalBuild(preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing model failed when generating cache.");
        return ret;
    }

    std::vector<Buffer> modelBuffers;
    ret = preparedModel->ExportModelCache(modelBuffers);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Exporting model cache failed.");
        return ret;
    }

    ret = GenerateCacheFiles(modelBuffers);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Generating cache files failed.");
        return ret;
    }

    LOGI("[Compilation] Exported model cache successfully.");
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::ReGenCacheBuild(uint32_t fileNumber, std::shared_ptr<PreparedModel>& preparedModel)
{
    OH_NN_ReturnCode ret = RemoveCacheFiles(fileNumber);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    ret = GenCacheBuild(preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Regenerating the cache build failed.");
        return ret;
    }

    LOGI("[Compilation] Updated model cache successfully.");
    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::LoadCacheBuild(std::shared_ptr<PreparedModel>& preparedModel,
    const ModelCacheInfo& cacheInfo)
{
    std::vector<Buffer> modelBuffers;
    OH_NN_ReturnCode ret = CheckCacheModel(cacheInfo, modelBuffers);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Checking the cache model failed.");
        size_t modelBuffersSize = modelBuffers.size();
        for (size_t i = 0; i < modelBuffersSize; ++i) {
            m_device->ReleaseBuffer(modelBuffers[i].data);
            modelBuffers[i].data = nullptr;
            modelBuffers[i].length = 0;
        }
        return ret;
    }

    ModelConfig config {m_enableFp16, m_performance, m_priority};
    ret = m_device->PrepareModelFromModelCache(modelBuffers, config, preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing the model from the cache failed.");
        return ret;
    }

    LOGI("[Compilation] Loaded the cache successfully.");

    for (auto& modelBuffer : modelBuffers) {
        ret = m_device->ReleaseBuffer(modelBuffer.data);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Releasing the cache model buffer failed.");
            return ret;
        }
    }
    modelBuffers.clear();

    m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
    if (m_executionPlan == nullptr) {
        LOGE("Failed to create ExecutionPlan instance.");
        return OH_NN_MEMORY_ERROR;
    }

    return OH_NN_SUCCESS;
}

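// Cache build strategy:
// - No cache info file yet: build online and export the cache (GenCacheBuild).
// - m_version > cached version: remove the stale files and regenerate (ReGenCacheBuild).
// - m_version < cached version: reject the build, the caller must raise the version.
// - Versions match: load from the cache, falling back to regeneration on failure.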
OH_NN_ReturnCode Compilation::BuildCacheModel(std::shared_ptr<PreparedModel>& preparedModel)
{
    OH_NN_ReturnCode ret;
    std::string cacheInfoPath = m_cachePath + m_modelName + "cache_info.nncache";
    if (access(cacheInfoPath.c_str(), F_OK) != 0) {
        ret = GenCacheBuild(preparedModel);
        if (ret != OH_NN_SUCCESS) {
            LOGE("Failed to build in cache-generating mode.");
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    ModelCacheInfo cacheInfo;
    ret = CheckCacheInfo(cacheInfo, cacheInfoPath);
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    if (m_version > cacheInfo.version) {
        ret = ReGenCacheBuild(cacheInfo.fileNumber, preparedModel);
        if (ret != OH_NN_SUCCESS) {
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    if (m_version < cacheInfo.version) {
        LOGE("[Compilation] The current version is lower than that of the cache files, please set a higher version.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    ret = LoadCacheBuild(preparedModel, cacheInfo);
    if (ret != OH_NN_SUCCESS) {
        // Recompile the model online and update the cache when loading the cache model fails.
        ret = ReGenCacheBuild(cacheInfo.fileNumber, preparedModel);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Failed to regenerate and build the cache model.");
            return ret;
        }
    }

    m_isBuild = true;

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::InnerBuild()
{
    OH_NN_ReturnCode ret;
    std::shared_ptr<PreparedModel> preparedModel;

    // Prepare from an offline model.
    bool isOfflineModel {false};
    ret = IsOfflineModel(isOfflineModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Failed to identify the offline model.");
        return ret;
    }

    if (isOfflineModel) {
        ret = BuildOfflineModel(preparedModel);
        if (ret != OH_NN_SUCCESS) {
            LOGE("[Compilation] Failed to build the offline model.");
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    if (m_cachePath.empty()) {
        ret = NormalBuild(preparedModel);
        if (ret != OH_NN_SUCCESS) {
            LOGE("Failed to build normally.");
            return ret;
        }

        m_isBuild = true;
        return OH_NN_SUCCESS;
    }

    ret = BuildCacheModel(preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("Failed to build the cache model.");
        return ret;
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::Build()
{
    NNRT_TRACE_NAME("Compilation");
    if (m_isBuild) {
        LOGE("[Compilation] Cannot build again after compilation finishes.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    if (m_device == nullptr) {
        LOGE("m_device is nullptr, please call SetDevice before building the model.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    OH_NN_ReturnCode ret = InnerBuild();
    if (ret != OH_NN_SUCCESS) {
        return ret;
    }

    return OH_NN_SUCCESS;
}

std::shared_ptr<ExecutionPlan> Compilation::GetExecutionPlan() const
{
    return m_executionPlan;
}

std::vector<std::shared_ptr<NNTensor>> Compilation::GetInputTensors() const
{
    return m_inputTensors;
}

std::vector<std::shared_ptr<NNTensor>> Compilation::GetOutputTensors() const
{
    return m_outputTensors;
}

bool Compilation::IsBuild() const
{
    return m_isBuild;
}

bool Compilation::IsDynamicShape() const
{
    size_t inputTensorsSize = m_inputTensors.size();
    for (size_t i = 0; i < inputTensorsSize; ++i) {
        if (m_inputTensors[i]->IsDynamicShape()) {
            return true;
        }
    }
    return false;
}

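// Offline-model heuristic: a model counts as offline only if its LiteGraph
// contains exactly one node and that node is of NODE_TYPE_CUSTOM.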
OH_NN_ReturnCode Compilation::IsOfflineModel(bool& isOfflineModel) const
{
    isOfflineModel = false; // Initialize the returned value
    if ((m_liteGraph == nullptr) && (m_metaGraph == nullptr)) {
        LOGE("[Compilation] LiteGraph and metaGraph are both empty when identifying the offline model.");
        return OH_NN_NULL_PTR;
    }

    if ((m_liteGraph != nullptr) && (m_metaGraph != nullptr)) {
        LOGE("[Compilation] LiteGraph and metaGraph are both non-empty when identifying the offline model.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (m_metaGraph != nullptr) {
        isOfflineModel = false;
        return OH_NN_SUCCESS;
    }

    if (m_liteGraph->all_nodes_.size() == 0) {
        LOGE("[Compilation] The model contains no nodes.");
        return OH_NN_INVALID_PARAMETER;
    }

    // If the model consists of more than one node, it is not considered an offline model.
    if (m_liteGraph->all_nodes_.size() > 1) {
        isOfflineModel = false;
        return OH_NN_SUCCESS;
    }

    const mindspore::lite::LiteGraph::Node* pNode = m_liteGraph->all_nodes_[0];
    if (pNode == nullptr) {
        LOGE("[Compilation] Found an invalid node in the model.");
        return OH_NN_NULL_PTR;
    }

    const mindspore::lite::NodeType& nodeType = mindspore::lite::MindIR_Primitive_GetType(pNode->primitive_);
    if (nodeType == mindspore::lite::NodeType::NODE_TYPE_CUSTOM) {
        isOfflineModel = true;
    }

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode Compilation::BuildOfflineModel(std::shared_ptr<PreparedModel>& preparedModel)
{
    ModelConfig config {m_enableFp16, m_performance, m_priority};
    OH_NN_ReturnCode ret = m_device->PrepareOfflineModel(m_liteGraph, config, preparedModel);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[Compilation] Preparing model failed when building from offline model.");
        return ret;
    }

    m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
    if (m_executionPlan == nullptr) {
        LOGE("[Compilation] Failed to create ExecutionPlan when building from offline model.");
        return OH_NN_MEMORY_ERROR;
    }

    return OH_NN_SUCCESS;
}
} // namespace NeuralNetworkRuntime
} // namespace OHOS