• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "extension_options_parser.h"
#include "stdlib.h"
#include <algorithm>
#include <charconv>
#include <map>
#include <string>
#include <string_view>
#include <system_error>
#include <vector>
20 
21 namespace mindspore::lite::nnrt {
22 namespace {
23 const std::map<std::string, mindspore::lite::HiAI_BandMode> kBandModeMap = {
24     {"HIAI_BANDMODE_UNSET", mindspore::lite::HIAI_BANDMODE_UNSET},
25     {"HIAI_BANDMODE_LOW", mindspore::lite::HIAI_BANDMODE_LOW},
26     {"HIAI_BANDMODE_NORMAL", mindspore::lite::HIAI_BANDMODE_NORMAL},
27     {"HIAI_BANDMODE_HIGH", mindspore::lite::HIAI_BANDMODE_HIGH},
28 };
29 const std::string kCachePath = "CachePath";
30 const std::string kCacheVersion = "CacheVersion";
31 const std::string kBandMode = "BandMode";
32 const std::string kQuantBuffer = "QuantBuffer";
33 const std::string kQuantConfigData = "QuantConfigData";
34 const std::string kModelName = "ModelName";
35 }  // namespace
36 
Parse(const std::vector<Extension> & extensions,ExtensionOptions * param)37 int ExtensionOptionsParser::Parse(const std::vector<Extension> &extensions, ExtensionOptions *param) {
38   MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
39 
40   DoParseCachePath(extensions, &param->cache_path_);
41   DoParseCacheVersion(extensions, &param->cache_version_);
42   DoParseBondMode(extensions, &param->band_mode);
43   DoParseQuantConfig(extensions, &param->quant_config, &param->quant_config_size, &param->is_optional_quant_setted);
44   DoParseModelName(extensions, &param->model_name);
45   return RET_OK;
46 }
47 
DoParseCachePath(const std::vector<Extension> & extensions,std::string * cache_path)48 void ExtensionOptionsParser::DoParseCachePath(const std::vector<Extension> &extensions, std::string *cache_path) {
49   MS_CHECK_TRUE_RET_VOID(cache_path != nullptr);
50   auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
51     return extension.name == kCachePath;
52   });
53   if (iter_config != extensions.end()) {
54     *cache_path = std::string(iter_config->value.begin(), iter_config->value.end());
55   }
56 }
57 
DoParseCacheVersion(const std::vector<Extension> & extensions,uint32_t * cache_version)58 void ExtensionOptionsParser::DoParseCacheVersion(const std::vector<Extension> &extensions, uint32_t *cache_version) {
59   MS_CHECK_TRUE_RET_VOID(cache_version != nullptr);
60   auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
61     return extension.name == kCacheVersion;
62   });
63   if (iter_config != extensions.end()) {
64     std::string version_str = std::string(iter_config->value.begin(), iter_config->value.end());
65     *cache_version = static_cast<uint32_t>(std::atol(version_str.c_str()));
66   }
67 }
68 
DoParseBondMode(const std::vector<Extension> & extensions,mindspore::lite::HiAI_BandMode * band_mode)69 void ExtensionOptionsParser::DoParseBondMode(const std::vector<Extension> &extensions, mindspore::lite::HiAI_BandMode *band_mode) {
70   MS_CHECK_TRUE_RET_VOID(band_mode != nullptr);
71   auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
72     return extension.name == kBandMode;
73   });
74   if (iter_config != extensions.end()) {
75     auto iter = kBandModeMap.find(std::string(iter_config->value.begin(), iter_config->value.end()));
76     if (iter != kBandModeMap.end()) {
77       *band_mode = iter->second;
78     }
79   }
80 }
81 
DoParseQuantConfig(const std::vector<Extension> & extensions,void ** quant_config,size_t * num,bool * quant_setted)82 void ExtensionOptionsParser::DoParseQuantConfig(const std::vector<Extension> &extensions,
83                                                 void **quant_config, size_t *num, bool *quant_setted) {
84   MS_CHECK_TRUE_RET_VOID(quant_config != nullptr);
85   MS_CHECK_TRUE_RET_VOID(num != nullptr);
86   auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
87     return extension.name == kQuantBuffer || extension.name == kQuantConfigData;
88   });
89   if (iter_config != extensions.end()) {
90     *quant_config = static_cast<void *>(const_cast<uint8_t *>(iter_config->value.data()));
91     *num = iter_config->value.size();
92     *quant_setted = true;
93   }
94 }
95 
DoParseModelName(const std::vector<Extension> & extensions,std::string * model_name)96 void ExtensionOptionsParser::DoParseModelName(const std::vector<Extension> &extensions, std::string *model_name) {
97   MS_CHECK_TRUE_RET_VOID(model_name != nullptr);
98   auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
99     return extension.name == kModelName;
100   });
101   if (iter_config != extensions.end()) {
102     *model_name = std::string(iter_config->value.begin(), iter_config->value.end());
103   }
104 }
105 }  // mindspore::lite::nnrt