1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "compilation.h"
17
18 #include <sys/stat.h>
19 #include <unistd.h>
20 #include <cstdio>
21 #include <sys/types.h>
22 #include <fstream>
23 #include <climits>
24
25 #include "common/utils.h"
26 #include "common/scoped_trace.h"
27 #include "validation.h"
28 #include "device_manager.h"
29
30 namespace OHOS {
31 namespace NeuralNetworkRuntime {
constexpr int MAX_MODEL_SIZE = 200 * 1024 * 1024; // Upper bound for a single cache file, 200MB.
constexpr int OCT_UNIT = 8; // Bit width of one byte, used as the CRC16 shift amount.
constexpr int NULL_PTR_LENGTH = 0; // Length that marks an empty buffer/stream.
constexpr int NUMBER_CACHE_INFO_MEMBERS = 3; // Cache info header fields: fileNumber, version, deviceId.
36
// CRC16 lookup table (CRC-16/CCITT), created based on the polynomial
// G(x) = x^16 + x^12 + x^5 + 1 (0x1021) with a CRC register initialization value of "0" (0x0000).
// Indexed by one byte at a time in GetCrc16().
static const unsigned short CRC16_TAB[256] = {
    0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7,
    0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef,
    0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6,
    0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de,
    0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485,
    0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d,
    0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4,
    0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc,
    0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823,
    0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b,
    0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12,
    0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a,
    0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41,
    0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49,
    0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70,
    0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78,
    0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f,
    0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067,
    0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e,
    0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256,
    0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d,
    0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405,
    0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c,
    0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634,
    0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab,
    0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3,
    0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a,
    0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92,
    0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9,
    0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1,
    0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8,
    0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0
};
73
// Builds a Compilation from a fully constructed InnerModel, caching the graph
// representations, I/O tensors and model metadata needed by the later build steps.
Compilation::Compilation(const InnerModel* innerModel)
    : m_liteGraph(innerModel->GetLiteGraphs()),
      m_inputTensors(innerModel->GetInputTensors()),
      m_outputTensors(innerModel->GetOutputTensors()),
      m_metaGraph(innerModel->GetMetaGraph()),
      m_quantBuffer(innerModel->GetQuantBuffer()),
      m_modelName(innerModel->GetModelName()) {}
81
SetDevice(size_t deviceId)82 OH_NN_ReturnCode Compilation::SetDevice(size_t deviceId)
83 {
84 if (m_isBuild) {
85 LOGE("Cannot set deviceId after compilation finish.");
86 return OH_NN_OPERATION_FORBIDDEN;
87 }
88
89 auto& deviceManager = DeviceManager::GetInstance();
90 std::shared_ptr<Device> availableDevice = deviceManager.GetDevice(deviceId);
91 if (availableDevice == nullptr) {
92 LOGE("[Compilation] DeviceId does not exist, deviceId=%zu", deviceId);
93 return OH_NN_INVALID_PARAMETER;
94 }
95
96 std::vector<bool> supportedList;
97 OH_NN_ReturnCode ret = availableDevice->GetSupportedOperation(m_liteGraph, supportedList);
98 if (ret != OH_NN_SUCCESS) {
99 LOGE("[Compilation] SetDevice failed, error happened when getting supported operation.");
100 return ret;
101 }
102
103 for (bool isSupport : supportedList) {
104 if (!isSupport) {
105 LOGE("[Compilation] SetDevice failed, current device not support the model, device id: %zu.", deviceId);
106 return OH_NN_FAILED;
107 }
108 }
109
110 bool supportDynamic;
111 ret = availableDevice->IsDynamicInputSupported(supportDynamic);
112 if (ret != OH_NN_SUCCESS) {
113 LOGE("[Compilation] SetDevice failed, error happened when checking whether device supports dynamic input.");
114 return ret;
115 }
116
117 if (IsDynamicShape() && (!supportDynamic)) {
118 LOGE("[Compilation] SetDevice failed."
119 "The device does not support dynamic shape inputs, but the model has dynamic inputs.");
120 return OH_NN_FAILED;
121 }
122
123 m_device = availableDevice;
124 m_deviceId = deviceId;
125 return OH_NN_SUCCESS;
126 }
127
SetCacheDir(const std::string & cacheModelPath,uint32_t version)128 OH_NN_ReturnCode Compilation::SetCacheDir(const std::string& cacheModelPath, uint32_t version)
129 {
130 if (m_isBuild) {
131 LOGE("Cannot set cache after compilation finish.");
132 return OH_NN_OPERATION_FORBIDDEN;
133 }
134
135 if (m_device == nullptr) {
136 LOGE("The parameter of m_device is nullptr, please call SetDevice function before calling SetCacheDir.");
137 return OH_NN_OPERATION_FORBIDDEN;
138 }
139
140 bool isSupportedCache {false};
141 OH_NN_ReturnCode ret = m_device->IsModelCacheSupported(isSupportedCache);
142 if (ret != OH_NN_SUCCESS) {
143 LOGE("[Compilation] Fail to query whether the device is available to save cache model.");
144 return ret;
145 }
146
147 if (!isSupportedCache) {
148 LOGE("[Compilation] The device is unavailable to save cache model.");
149 return OH_NN_OPERATION_FORBIDDEN;
150 }
151
152 char realPathRes[PATH_MAX];
153 const char* filePath = realpath(cacheModelPath.c_str(), realPathRes);
154 if (filePath == nullptr) {
155 LOGE("[Compilation] The cache model path is invalid.");
156 return OH_NN_INVALID_PARAMETER;
157 }
158
159 struct stat fileInfo;
160 if (stat(filePath, &fileInfo) != 0) {
161 LOGE("[Compilation] The cache directory does not exist or cannot be accessed.");
162 return OH_NN_INVALID_PARAMETER;
163 }
164
165 if (!(fileInfo.st_mode & S_IFDIR)) {
166 LOGE("[Compilation] The cache model path is not a directory.");
167 return OH_NN_INVALID_PARAMETER;
168 }
169
170 m_cachePath = (std::string)filePath + "/";
171 m_version = version;
172 return OH_NN_SUCCESS;
173 }
174
SetPerformance(OH_NN_PerformanceMode performance)175 OH_NN_ReturnCode Compilation::SetPerformance(OH_NN_PerformanceMode performance)
176 {
177 if (m_isBuild) {
178 LOGE("[Compilation] Cannot set performance after compilation finish.");
179 return OH_NN_OPERATION_FORBIDDEN;
180 }
181
182 if (m_device == nullptr) {
183 LOGE("Cannot set performance before set device, please set device first");
184 return OH_NN_OPERATION_FORBIDDEN;
185 }
186
187 bool isSupportedPerformance {false};
188 OH_NN_ReturnCode ret = m_device->IsPerformanceModeSupported(isSupportedPerformance);
189 if (ret != OH_NN_SUCCESS) {
190 LOGE("[Compilation] Call device %zu failed.", m_deviceId);
191 return ret;
192 }
193
194 if (!isSupportedPerformance) {
195 LOGE("[Compilation] This device %zu is not support performance setting.", m_deviceId);
196 return OH_NN_OPERATION_FORBIDDEN;
197 }
198
199 if (!Validation::ValidatePerformanceMode(performance)) {
200 LOGE("[Compilation] SetPerformance passed invalid performance=%d", performance);
201 return OH_NN_INVALID_PARAMETER;
202 }
203
204 m_performance = performance;
205 return OH_NN_SUCCESS;
206 }
207
SetPriority(OH_NN_Priority priority)208 OH_NN_ReturnCode Compilation::SetPriority(OH_NN_Priority priority)
209 {
210 if (m_isBuild) {
211 LOGE("[Compilation] Cannot set priority after compilation finish.");
212 return OH_NN_OPERATION_FORBIDDEN;
213 }
214
215 if (m_device == nullptr) {
216 LOGE("Cannot set priority before set device, please set device first");
217 return OH_NN_OPERATION_FORBIDDEN;
218 }
219
220 bool isSupportedPriority {false};
221 OH_NN_ReturnCode ret = m_device->IsPrioritySupported(isSupportedPriority);
222 if (ret != OH_NN_SUCCESS) {
223 LOGE("[Compilation] Call device %zu failed.", m_deviceId);
224 return ret;
225 }
226
227 if (!isSupportedPriority) {
228 LOGE("[Compilation] This device %zu is not support priority setting.", m_deviceId);
229 return OH_NN_OPERATION_FORBIDDEN;
230 }
231
232 if (!Validation::ValidatePriority(priority)) {
233 LOGE("[Compilation] SetPriority passed invalid priority=%d", priority);
234 return OH_NN_INVALID_PARAMETER;
235 }
236
237 m_priority = priority;
238 return OH_NN_SUCCESS;
239 }
240
SetEnableFp16(bool isFp16)241 OH_NN_ReturnCode Compilation::SetEnableFp16(bool isFp16)
242 {
243 if (m_isBuild) {
244 LOGE("[Compilation] Cannot enable float16 after compilation finish.");
245 return OH_NN_OPERATION_FORBIDDEN;
246 }
247
248 if (m_device == nullptr) {
249 LOGE("Cannot set enable fp16 before set device, please set device first");
250 return OH_NN_OPERATION_FORBIDDEN;
251 }
252
253 bool isSupportedFp16 {false};
254 OH_NN_ReturnCode ret = m_device->IsFloat16PrecisionSupported(isSupportedFp16);
255 if (ret != OH_NN_SUCCESS) {
256 LOGE("[Compilation] Call device %zu failed.", m_deviceId);
257 return ret;
258 }
259
260 if (!isSupportedFp16) {
261 LOGE("[Compilation] This device %zu is not support float16 precision setting.", m_deviceId);
262 return OH_NN_OPERATION_FORBIDDEN;
263 }
264
265 m_enableFp16 = isFp16;
266 return OH_NN_SUCCESS;
267 }
268
GetCrc16(const unsigned char * buffer,size_t length) const269 unsigned short Compilation::GetCrc16(const unsigned char* buffer, size_t length) const
270 {
271 unsigned short crc16 = 0;
272 for (size_t i = 0; i < length; ++i) {
273 uint8_t tableIndex = ((crc16 >> OCT_UNIT) ^ *buffer++) & 0x00ff;
274 crc16 = (crc16 << OCT_UNIT) ^ CRC16_TAB[tableIndex];
275 }
276 return crc16;
277 }
278
GenerateCacheInfo(uint32_t cacheSize,std::unique_ptr<uint64_t[]> & cacheInfo) const279 OH_NN_ReturnCode Compilation::GenerateCacheInfo(uint32_t cacheSize, std::unique_ptr<uint64_t[]>& cacheInfo) const
280 {
281 std::string cacheInfoPath = m_cachePath + m_modelName + "cache_info.nncache";
282 std::ofstream cacheInfoStream(cacheInfoPath, std::ios::binary | std::ios::out | std::ios::trunc);
283 if (cacheInfoStream.fail()) {
284 LOGE("[Compilation] Model cache info file is invalid.");
285 return OH_NN_INVALID_FILE;
286 }
287
288 if (!cacheInfoStream.write(reinterpret_cast<const char*>(cacheInfo.get()), cacheSize)) {
289 LOGE("[Compilation] Fail to write cache info.");
290 cacheInfoStream.close();
291 return OH_NN_FAILED;
292 }
293
294 cacheInfoStream.close();
295 return OH_NN_SUCCESS;
296 }
297
GenerateCacheModel(size_t cacheNumber,std::unique_ptr<uint64_t[]> & cacheInfo,std::vector<Buffer> modelBuffer) const298 OH_NN_ReturnCode Compilation::GenerateCacheModel(size_t cacheNumber, std::unique_ptr<uint64_t[]>& cacheInfo,
299 std::vector<Buffer> modelBuffer) const
300 {
301 auto cacheInfoPtr = cacheInfo.get();
302 *cacheInfoPtr++ = static_cast<uint64_t>(cacheNumber);
303 *cacheInfoPtr++ = static_cast<uint64_t>(m_version);
304 *cacheInfoPtr++ = static_cast<uint64_t>(m_deviceId);
305 for (uint32_t i = 0; i < cacheNumber; ++i) {
306 std::string cacheModelFile = m_cachePath + m_modelName + std::to_string(i) + ".nncache";
307 std::ofstream cacheModelStream(cacheModelFile, std::ios::binary | std::ios::out | std::ios::trunc);
308 if (cacheModelStream.fail()) {
309 LOGE("[Compilation] Model cache file is invalid.");
310 return OH_NN_INVALID_FILE;
311 }
312
313 uint64_t checkSum = static_cast<uint64_t>(GetCrc16(static_cast<const unsigned char*>(modelBuffer[i].data),
314 modelBuffer[i].length));
315 *cacheInfoPtr++ = checkSum;
316 if (!cacheModelStream.write(static_cast<const char*>(modelBuffer[i].data), modelBuffer[i].length)) {
317 LOGE("[Compilation] Fail to write cache model.");
318 cacheModelStream.close();
319 return OH_NN_FAILED;
320 };
321
322 cacheModelStream.close();
323 }
324
325 return OH_NN_SUCCESS;
326 }
327
GenerateCacheFiles(const std::vector<Buffer> & modelBuffer) const328 OH_NN_ReturnCode Compilation::GenerateCacheFiles(const std::vector<Buffer>& modelBuffer) const
329 {
330 const size_t cacheNumber = modelBuffer.size();
331 uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber;
332 std::unique_ptr<uint64_t[]> cacheInfo = std::make_unique<uint64_t[]>(cacheSize);
333 if (cacheInfo == nullptr) {
334 LOGE("Fail to create cacheInfo instance.");
335 return OH_NN_MEMORY_ERROR;
336 }
337
338 OH_NN_ReturnCode ret = GenerateCacheModel(cacheNumber, cacheInfo, modelBuffer);
339 if (ret != OH_NN_SUCCESS) {
340 return ret;
341 }
342
343 uint32_t infoCharNumber = cacheSize * sizeof(uint64_t);
344 ret = GenerateCacheInfo(infoCharNumber, cacheInfo);
345 if (ret != OH_NN_SUCCESS) {
346 return ret;
347 }
348
349 return OH_NN_SUCCESS;
350 }
351
GetCacheFileLength(std::ifstream & ifs,int & fsize) const352 OH_NN_ReturnCode Compilation::GetCacheFileLength(std::ifstream& ifs, int& fsize) const
353 {
354 ifs.seekg(0, std::ios::end);
355 if (!ifs.good()) {
356 LOGE("[Compilation] Fail to set the position of the next character to be extracted from the input stream.");
357 return OH_NN_INVALID_FILE;
358 }
359
360 int handleValue = ifs.tellg();
361 if (handleValue == -1) {
362 LOGE("[Compilation] Unable to get position of the input stream.");
363 return OH_NN_INVALID_FILE;
364 }
365
366 if ((handleValue > MAX_MODEL_SIZE) || (handleValue == NULL_PTR_LENGTH)) {
367 LOGE("[Compilation] Unable to read huge or empty input stream, get cache file size=%d", handleValue);
368 return OH_NN_INVALID_FILE;
369 }
370
371 fsize = handleValue;
372 return OH_NN_SUCCESS;
373 }
374
ReadCacheModelFile(const std::string & file,Buffer & modelBuffer) const375 OH_NN_ReturnCode Compilation::ReadCacheModelFile(const std::string& file, Buffer& modelBuffer) const
376 {
377 // file is validated outside.
378 std::ifstream ifs(file.c_str(), std::ios::in | std::ios::binary);
379 if (!ifs) {
380 LOGE("[Compilation] Fail to open cache file.");
381 return OH_NN_INVALID_FILE;
382 }
383
384 int fsize {-1};
385 OH_NN_ReturnCode ret = GetCacheFileLength(ifs, fsize);
386 if (ret != OH_NN_SUCCESS) {
387 ifs.close();
388 return ret;
389 }
390
391 ifs.seekg(0, std::ios::beg);
392 if (!ifs.good()) {
393 LOGE("[Compilation] Fail to set the position of the next character to be extracted"
394 "from the cache model stream.");
395 ifs.close();
396 return OH_NN_FAILED;
397 }
398
399 char* ptr = static_cast<char*>(m_device->AllocateBuffer(fsize));
400 if (ptr == nullptr) {
401 LOGE("[Compilation] Fail to create file buffer.");
402 ifs.close();
403 return OH_NN_NULL_PTR;
404 }
405
406 ifs.read(ptr, fsize);
407 if (!ifs.good()) {
408 LOGE("[Compilation] Fail to read the characters from the cache model stream.");
409 ifs.close();
410 m_device->ReleaseBuffer(ptr);
411 ptr = nullptr;
412 return OH_NN_FAILED;
413 }
414
415 ifs.close();
416 modelBuffer.data = ptr;
417 modelBuffer.length = static_cast<size_t>(fsize); // fsize should be non-negative, safe to cast.
418 return OH_NN_SUCCESS;
419 }
420
CheckCacheInfo(ModelCacheInfo & modelCacheInfo,const std::string & cacheInfoPath) const421 OH_NN_ReturnCode Compilation::CheckCacheInfo(ModelCacheInfo& modelCacheInfo, const std::string& cacheInfoPath) const
422 {
423 // cacheInfoPath is validated outside.
424 std::ifstream infoCacheFile(cacheInfoPath.c_str(), std::ios::in | std::ios::binary);
425 if (!infoCacheFile) {
426 LOGE("[Compilation] Opening cache info file failed.");
427 return OH_NN_INVALID_FILE;
428 }
429
430 int charNumber = NUMBER_CACHE_INFO_MEMBERS * sizeof(uint64_t);
431 if (!infoCacheFile.read((char*)&(modelCacheInfo), charNumber)) {
432 LOGE("[Compilation] Fail to get the content of info cache file.");
433 infoCacheFile.close();
434 return OH_NN_INVALID_FILE;
435 }
436
437 // modelCacheInfo.deviceId type is int64_t,
438 // it is transformed from size_t value, so the transform here will not truncate value.
439 size_t deviceId = static_cast<size_t>(modelCacheInfo.deviceId);
440 if (deviceId != m_deviceId) {
441 LOGE("[Compilation] The deviceId=%zu in the cache files is different from current deviceId=%zu,"
442 "please change the cache directory or current deviceId.", deviceId, m_deviceId);
443 infoCacheFile.close();
444 return OH_NN_INVALID_PARAMETER;
445 }
446
447 std::vector<uint64_t> modelCheckSum;
448 modelCheckSum.resize(modelCacheInfo.fileNumber);
449 modelCacheInfo.modelCheckSum.resize(modelCacheInfo.fileNumber);
450 if (!infoCacheFile.read((char*)&modelCheckSum[0], modelCacheInfo.fileNumber * sizeof(uint64_t))) {
451 LOGE("[Compilation] The info cache file has been changed.");
452 infoCacheFile.close();
453 return OH_NN_INVALID_FILE;
454 }
455
456 for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
457 modelCacheInfo.modelCheckSum[i] = static_cast<unsigned short>(modelCheckSum[i]);
458 }
459
460 return OH_NN_SUCCESS;
461 }
462
RemoveCacheFiles(uint32_t fileNumber) const463 OH_NN_ReturnCode Compilation::RemoveCacheFiles(uint32_t fileNumber) const
464 {
465 std::string cacheInfoPath = m_cachePath + m_modelName + "cache_info.nncache";
466 if (remove(cacheInfoPath.c_str()) == -1) {
467 LOGE("[Compilation] Fail to remove the file %s, please delete the file manually.", cacheInfoPath.c_str());
468 return OH_NN_FAILED;
469 }
470 LOGI("[Compilation] Succeed to remove the file cache_info.nncach.");
471
472 for (uint32_t i = 0; i < fileNumber; ++i) {
473 std::string fileName = m_modelName + std::to_string(i) + ".nncache";
474 std::string cacheModelPath = m_cachePath + fileName;
475 if (access(cacheModelPath.c_str(), 0) != 0) {
476 LOGW("[Compilation] The file %s does not exist, no need to delete the file.", cacheModelPath.c_str());
477 continue;
478 }
479
480 if (remove(cacheModelPath.c_str()) == -1) {
481 LOGE("[Compilation] Fail to remove the file %s, please delete the file manually.", cacheModelPath.c_str());
482 return OH_NN_FAILED;
483 }
484 LOGI("[Compilation] Succeed to remove the file %s", cacheModelPath.c_str());
485 }
486 return OH_NN_SUCCESS;
487 }
488
CheckCacheModel(const ModelCacheInfo & modelCacheInfo,std::vector<Buffer> & modelBuffers) const489 OH_NN_ReturnCode Compilation::CheckCacheModel(const ModelCacheInfo& modelCacheInfo,
490 std::vector<Buffer>& modelBuffers) const
491 {
492 for (uint32_t i = 0; i < modelCacheInfo.fileNumber; ++i) {
493 std::string cacheModelPath = m_cachePath + m_modelName + std::to_string(i) + ".nncache";
494 if (access(cacheModelPath.c_str(), 0) != 0) {
495 LOGE("[Compilation] The cache model file %s does not exist.", cacheModelPath.c_str());
496 return OH_NN_INVALID_FILE;
497 }
498
499 Buffer modelBuffer;
500 OH_NN_ReturnCode ret = ReadCacheModelFile(cacheModelPath, modelBuffer);
501 if (ret != OH_NN_SUCCESS) {
502 LOGE("[Compilation] Read cache model file failed.");
503 return ret;
504 }
505
506 if (GetCrc16(static_cast<const unsigned char*>(modelBuffer.data),
507 modelBuffer.length) != modelCacheInfo.modelCheckSum[i]) {
508 LOGE("[Compilation] The cache model file %s has been changed.", cacheModelPath.c_str());
509 return OH_NN_INVALID_FILE;
510 }
511
512 modelBuffers.emplace_back(std::move(modelBuffer));
513 }
514
515 return OH_NN_SUCCESS;
516 }
517
NormalBuild(std::shared_ptr<PreparedModel> & preparedModel)518 OH_NN_ReturnCode Compilation::NormalBuild(std::shared_ptr<PreparedModel>& preparedModel)
519 {
520 ModelConfig config {m_enableFp16, m_performance, m_priority};
521 if ((m_liteGraph == nullptr) && (m_metaGraph == nullptr)) {
522 LOGE("[Compilation] Both m_liteGraph and m_metaGraph are nullptr.");
523 return OH_NN_INVALID_PARAMETER;
524 }
525
526 if ((m_liteGraph != nullptr) && (m_metaGraph != nullptr)) {
527 LOGE("[Compilation] Neither m_liteGraph nor m_metaGraph are nullptr.");
528 return OH_NN_INVALID_PARAMETER;
529 }
530
531 OH_NN_ReturnCode ret {OH_NN_FAILED};
532 if (m_liteGraph != nullptr) {
533 ret = m_device->PrepareModel(m_liteGraph, config, preparedModel);
534 }
535 if (m_metaGraph != nullptr) {
536 ret = m_device->PrepareModel(m_metaGraph, m_quantBuffer, config, preparedModel);
537 }
538 if (ret != OH_NN_SUCCESS) {
539 LOGE("[Compilation] Preparing model failed when normally building.");
540 return ret;
541 }
542
543 m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
544 if (m_executionPlan == nullptr) {
545 LOGE("[Compilation] Fail to create ExecutionPlan instance.");
546 return OH_NN_MEMORY_ERROR;
547 }
548
549 return OH_NN_SUCCESS;
550 }
551
GenCacheBuild(std::shared_ptr<PreparedModel> & preparedModel)552 OH_NN_ReturnCode Compilation::GenCacheBuild(std::shared_ptr<PreparedModel>& preparedModel)
553 {
554 OH_NN_ReturnCode ret = NormalBuild(preparedModel);
555 if (ret != OH_NN_SUCCESS) {
556 LOGE("[Compilation] Preparing model failed when generating cache.");
557 return ret;
558 }
559
560 std::vector<Buffer> modelBuffers;
561 ret = preparedModel->ExportModelCache(modelBuffers);
562 if (ret != OH_NN_SUCCESS) {
563 LOGE("[Compilation] Export model cache failed.");
564 return ret;
565 }
566
567 ret = GenerateCacheFiles(modelBuffers);
568 if (ret != OH_NN_SUCCESS) {
569 LOGE("[Compilation] Generate cache files failed.");
570 return ret;
571 }
572
573 LOGI("[Compilation] Export model cache successfully.");
574 return OH_NN_SUCCESS;
575 }
576
ReGenCacheBuild(uint32_t fileNumber,std::shared_ptr<PreparedModel> & preparedModel)577 OH_NN_ReturnCode Compilation::ReGenCacheBuild(uint32_t fileNumber, std::shared_ptr<PreparedModel>& preparedModel)
578 {
579 OH_NN_ReturnCode ret = RemoveCacheFiles(fileNumber);
580 if (ret != OH_NN_SUCCESS) {
581 return ret;
582 }
583
584 ret = GenCacheBuild(preparedModel);
585 if (ret != OH_NN_SUCCESS) {
586 LOGE("[Compilation] Generating cache building failed.");
587 return ret;
588 }
589
590 LOGI("[Compilation] Update model cache successfully.");
591 return OH_NN_SUCCESS;
592 }
593
LoadCacheBuild(std::shared_ptr<PreparedModel> & preparedModel,const ModelCacheInfo & cacheInfo)594 OH_NN_ReturnCode Compilation::LoadCacheBuild(std::shared_ptr<PreparedModel>& preparedModel,
595 const ModelCacheInfo& cacheInfo)
596 {
597 std::vector<Buffer> modelBuffers;
598 OH_NN_ReturnCode ret = CheckCacheModel(cacheInfo, modelBuffers);
599 if (ret != OH_NN_SUCCESS) {
600 LOGE("[Compilation] Checking cache model failed.");
601 size_t modelBuffersSize = modelBuffers.size();
602 for (size_t i = 0; i < modelBuffersSize; ++i) {
603 m_device->ReleaseBuffer(modelBuffers[i].data);
604 modelBuffers[i].data = nullptr;
605 modelBuffers[i].length = 0;
606 }
607 return ret;
608 }
609
610 ModelConfig config {m_enableFp16, m_performance, m_priority};
611 ret = m_device->PrepareModelFromModelCache(modelBuffers, config, preparedModel);
612 if (ret != OH_NN_SUCCESS) {
613 LOGE("[Compilation] Preparing model from cache failed.");
614 return ret;
615 }
616
617 LOGI("[Compilation] Load cache successfully.");
618
619 for (auto& modelBuffer : modelBuffers) {
620 ret = m_device->ReleaseBuffer(modelBuffer.data);
621 if (ret != OH_NN_SUCCESS) {
622 LOGE("[Compilation] Release cache model buffer failed.");
623 return ret;
624 }
625 }
626 modelBuffers.clear();
627
628 m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
629 if (m_executionPlan == nullptr) {
630 LOGE("Fail to create ExecutionPlan instance.");
631 return OH_NN_MEMORY_ERROR;
632 }
633
634 return OH_NN_SUCCESS;
635 }
636
BuildCacheModel(std::shared_ptr<PreparedModel> & preparedModel)637 OH_NN_ReturnCode Compilation::BuildCacheModel(std::shared_ptr<PreparedModel>& preparedModel)
638 {
639 OH_NN_ReturnCode ret;
640 std::string cacheInfoPath = m_cachePath + m_modelName + "cache_info.nncache";
641 if (access(cacheInfoPath.c_str(), 0) != 0) {
642 ret = GenCacheBuild(preparedModel);
643 if (ret != OH_NN_SUCCESS) {
644 LOGE("Fail to build in generating cache mode.");
645 return ret;
646 }
647
648 m_isBuild = true;
649 return OH_NN_SUCCESS;
650 }
651
652 ModelCacheInfo cacheInfo;
653 ret = CheckCacheInfo(cacheInfo, cacheInfoPath);
654 if (ret != OH_NN_SUCCESS) {
655 return ret;
656 }
657
658 if (m_version > cacheInfo.version) {
659 ret = ReGenCacheBuild(cacheInfo.fileNumber, preparedModel);
660 if (ret != OH_NN_SUCCESS) {
661 return ret;
662 }
663
664 m_isBuild = true;
665 return OH_NN_SUCCESS;
666 }
667
668 if (m_version < cacheInfo.version) {
669 LOGE("[Compilation] The current version is lower than the cache files, please set a higher version.");
670 return OH_NN_OPERATION_FORBIDDEN;
671 }
672
673 ret = LoadCacheBuild(preparedModel, cacheInfo);
674 if (ret != OH_NN_SUCCESS) {
675 // recompile the model online and update the cache when failing to build cache model
676 ret = ReGenCacheBuild(cacheInfo.fileNumber, preparedModel);
677 if (ret != OH_NN_SUCCESS) {
678 LOGE("[Compilation] Failed to re-generate and build cache model.");
679 return ret;
680 }
681 }
682
683 m_isBuild = true;
684
685 return OH_NN_SUCCESS;
686 }
687
InnerBuild()688 OH_NN_ReturnCode Compilation::InnerBuild()
689 {
690 OH_NN_ReturnCode ret;
691 std::shared_ptr<PreparedModel> preparedModel;
692
693 // Prepare from offline model.
694 bool isOfflineModel {false};
695 ret = IsOfflineModel(isOfflineModel);
696 if (ret != OH_NN_SUCCESS) {
697 LOGE("[Compilation] Failed when identifying the offline model.");
698 return ret;
699 }
700
701 if (isOfflineModel) {
702 ret = BuildOfflineModel(preparedModel);
703 if (ret != OH_NN_SUCCESS) {
704 LOGE("[Compilation] Failed to build offline model.");
705 return ret;
706 }
707
708 m_isBuild = true;
709 return OH_NN_SUCCESS;
710 }
711
712 if (m_cachePath.empty()) {
713 ret = NormalBuild(preparedModel);
714 if (ret != OH_NN_SUCCESS) {
715 LOGE("Fail to normally build.");
716 return ret;
717 }
718
719 m_isBuild = true;
720 return OH_NN_SUCCESS;
721 }
722
723 ret = BuildCacheModel(preparedModel);
724 if (ret != OH_NN_SUCCESS) {
725 LOGE("Fail to build cache model.");
726 return ret;
727 }
728
729 return OH_NN_SUCCESS;
730 }
731
Build()732 OH_NN_ReturnCode Compilation::Build()
733 {
734 NNRT_TRACE_NAME("Compilation");
735 if (m_isBuild) {
736 LOGE("[Compilation] Cannot enable float16 after compilation finish.");
737 return OH_NN_OPERATION_FORBIDDEN;
738 }
739
740 if (m_device == nullptr) {
741 LOGE("The parameter of m_device is nullptr, please call SetDevice function before build model.");
742 return OH_NN_OPERATION_FORBIDDEN;
743 }
744
745 OH_NN_ReturnCode ret = InnerBuild();
746 if (ret != OH_NN_SUCCESS) {
747 return ret;
748 }
749
750 return OH_NN_SUCCESS;
751 }
752
// Returns the execution plan created during a successful Build().
std::shared_ptr<ExecutionPlan> Compilation::GetExecutionPlan() const
{
    return m_executionPlan;
}
757
// Returns (a copy of) the model's input tensor list captured at construction.
std::vector<std::shared_ptr<NNTensor>> Compilation::GetInputTensors() const
{
    return m_inputTensors;
}
762
// Returns (a copy of) the model's output tensor list captured at construction.
std::vector<std::shared_ptr<NNTensor>> Compilation::GetOutputTensors() const
{
    return m_outputTensors;
}
767
// Indicates whether Build() has already completed successfully.
bool Compilation::IsBuild() const
{
    return m_isBuild;
}
772
IsDynamicShape() const773 bool Compilation::IsDynamicShape() const
774 {
775 size_t inputTensorsSize = m_inputTensors.size();
776 for (size_t i = 0; i < inputTensorsSize; ++i) {
777 if (m_inputTensors[i]->IsDynamicShape()) {
778 return true;
779 }
780 }
781 return false;
782 }
783
IsOfflineModel(bool & isOfflineModel) const784 OH_NN_ReturnCode Compilation::IsOfflineModel(bool& isOfflineModel) const
785 {
786 isOfflineModel = false; // Initialize the returned value
787 if ((m_liteGraph == nullptr) && (m_metaGraph == nullptr)) {
788 LOGE("[Compilation] LiteGraph and metaGraph are empty when identifying the offline model.");
789 return OH_NN_NULL_PTR;
790 }
791
792 if ((m_liteGraph != nullptr) && (m_metaGraph != nullptr)) {
793 LOGE("[Compilation] LiteGraph and metaGraph are not empty when identifying the offline model.");
794 return OH_NN_INVALID_PARAMETER;
795 }
796
797 if (m_metaGraph != nullptr) {
798 isOfflineModel = false;
799 return OH_NN_SUCCESS;
800 }
801
802 if (m_liteGraph->all_nodes_.size() == 0) {
803 LOGE("[Compilation] Find empty node in the model.");
804 return OH_NN_INVALID_PARAMETER;
805 }
806
807 // If the model consists of more than 1 node, it will not be considered as offline model.
808 if (m_liteGraph->all_nodes_.size() > 1) {
809 isOfflineModel = false;
810 return OH_NN_SUCCESS;
811 }
812
813 const mindspore::lite::LiteGraph::Node* pNode = m_liteGraph->all_nodes_[0];
814 if (pNode == nullptr) {
815 LOGE("[Compilation] Find invalid node in the model.");
816 return OH_NN_NULL_PTR;
817 }
818
819 const mindspore::lite::NodeType& nodeType = mindspore::lite::MindIR_Primitive_GetType(pNode->primitive_);
820 if (nodeType == mindspore::lite::NodeType::NODE_TYPE_CUSTOM) {
821 isOfflineModel = true;
822 }
823
824 return OH_NN_SUCCESS;
825 }
826
BuildOfflineModel(std::shared_ptr<PreparedModel> & preparedModel)827 OH_NN_ReturnCode Compilation::BuildOfflineModel(std::shared_ptr<PreparedModel>& preparedModel)
828 {
829 ModelConfig config {m_enableFp16, m_performance, m_priority};
830 OH_NN_ReturnCode ret = m_device->PrepareOfflineModel(m_liteGraph, config, preparedModel);
831 if (ret != OH_NN_SUCCESS) {
832 LOGE("[Compilation] Preparing model failed when building from offline model.");
833 return ret;
834 }
835
836 m_executionPlan = CreateSharedPtr<ExecutionPlan>(preparedModel, m_device);
837 if (m_executionPlan == nullptr) {
838 LOGE("[Compilation] Failed to create ExecutionPlan when building from offline model.");
839 return OH_NN_MEMORY_ERROR;
840 }
841
842 return OH_NN_SUCCESS;
843 }
844 } // namespace NeuralNetworkRuntime
845 } // namespace OHOS