1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "extension_options_parser.h"
18 #include "stdlib.h"
19 #include <map>
20
21 namespace mindspore::lite::nnrt {
22 namespace {
23 const std::map<std::string, mindspore::lite::HiAI_BandMode> kBandModeMap = {
24 {"HIAI_BANDMODE_UNSET", mindspore::lite::HIAI_BANDMODE_UNSET},
25 {"HIAI_BANDMODE_LOW", mindspore::lite::HIAI_BANDMODE_LOW},
26 {"HIAI_BANDMODE_NORMAL", mindspore::lite::HIAI_BANDMODE_NORMAL},
27 {"HIAI_BANDMODE_HIGH", mindspore::lite::HIAI_BANDMODE_HIGH},
28 };
29 const std::string kCachePath = "CachePath";
30 const std::string kCacheVersion = "CacheVersion";
31 const std::string kBandMode = "BandMode";
32 const std::string kQuantBuffer = "QuantBuffer";
33 const std::string kQuantConfigData = "QuantConfigData";
34 const std::string kModelName = "ModelName";
35 } // namespace
36
Parse(const std::vector<Extension> & extensions,ExtensionOptions * param)37 int ExtensionOptionsParser::Parse(const std::vector<Extension> &extensions, ExtensionOptions *param) {
38 MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
39
40 DoParseCachePath(extensions, ¶m->cache_path_);
41 DoParseCacheVersion(extensions, ¶m->cache_version_);
42 DoParseBondMode(extensions, ¶m->band_mode);
43 DoParseQuantConfig(extensions, ¶m->quant_config, ¶m->quant_config_size, ¶m->is_optional_quant_setted);
44 DoParseModelName(extensions, ¶m->model_name);
45 return RET_OK;
46 }
47
DoParseCachePath(const std::vector<Extension> & extensions,std::string * cache_path)48 void ExtensionOptionsParser::DoParseCachePath(const std::vector<Extension> &extensions, std::string *cache_path) {
49 MS_CHECK_TRUE_RET_VOID(cache_path != nullptr);
50 auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
51 return extension.name == kCachePath;
52 });
53 if (iter_config != extensions.end()) {
54 *cache_path = std::string(iter_config->value.begin(), iter_config->value.end());
55 }
56 }
57
DoParseCacheVersion(const std::vector<Extension> & extensions,uint32_t * cache_version)58 void ExtensionOptionsParser::DoParseCacheVersion(const std::vector<Extension> &extensions, uint32_t *cache_version) {
59 MS_CHECK_TRUE_RET_VOID(cache_version != nullptr);
60 auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
61 return extension.name == kCacheVersion;
62 });
63 if (iter_config != extensions.end()) {
64 std::string version_str = std::string(iter_config->value.begin(), iter_config->value.end());
65 *cache_version = static_cast<uint32_t>(std::atol(version_str.c_str()));
66 }
67 }
68
DoParseBondMode(const std::vector<Extension> & extensions,mindspore::lite::HiAI_BandMode * band_mode)69 void ExtensionOptionsParser::DoParseBondMode(const std::vector<Extension> &extensions, mindspore::lite::HiAI_BandMode *band_mode) {
70 MS_CHECK_TRUE_RET_VOID(band_mode != nullptr);
71 auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
72 return extension.name == kBandMode;
73 });
74 if (iter_config != extensions.end()) {
75 auto iter = kBandModeMap.find(std::string(iter_config->value.begin(), iter_config->value.end()));
76 if (iter != kBandModeMap.end()) {
77 *band_mode = iter->second;
78 }
79 }
80 }
81
DoParseQuantConfig(const std::vector<Extension> & extensions,void ** quant_config,size_t * num,bool * quant_setted)82 void ExtensionOptionsParser::DoParseQuantConfig(const std::vector<Extension> &extensions,
83 void **quant_config, size_t *num, bool *quant_setted) {
84 MS_CHECK_TRUE_RET_VOID(quant_config != nullptr);
85 MS_CHECK_TRUE_RET_VOID(num != nullptr);
86 auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
87 return extension.name == kQuantBuffer || extension.name == kQuantConfigData;
88 });
89 if (iter_config != extensions.end()) {
90 *quant_config = static_cast<void *>(const_cast<uint8_t *>(iter_config->value.data()));
91 *num = iter_config->value.size();
92 *quant_setted = true;
93 }
94 }
95
DoParseModelName(const std::vector<Extension> & extensions,std::string * model_name)96 void ExtensionOptionsParser::DoParseModelName(const std::vector<Extension> &extensions, std::string *model_name) {
97 MS_CHECK_TRUE_RET_VOID(model_name != nullptr);
98 auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
99 return extension.name == kModelName;
100 });
101 if (iter_config != extensions.end()) {
102 *model_name = std::string(iter_config->value.begin(), iter_config->value.end());
103 }
104 }
105 } // mindspore::lite::nnrt