1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "cxx_api/model/acl/acl_model_options.h"
17 #include <set>
18 #include <memory>
19 #include "utils/log_adapter.h"
20 #include "ge/ge_api_types.h"
21 #include "cxx_api/acl_utils.h"
22 #include "transform/symbol/acl_base_symbol.h"
23 #include "transform/symbol/symbol_utils.h"
24
25 namespace mindspore {
26 static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"},
27 {DataType::kNumberTypeFloat32, "FP32"},
28 {DataType::kNumberTypeUInt8, "UINT8"}};
29
AclModelOptions(const std::shared_ptr<Context> & context)30 AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
31 if (context == nullptr) {
32 return;
33 }
34 auto &device_infos = context->MutableDeviceInfo();
35 if (device_infos.size() != 1) {
36 return;
37 }
38 auto ascend_info = device_infos[0]->Cast<AscendDeviceInfo>();
39 if (ascend_info == nullptr) {
40 return;
41 }
42
43 insert_op_cfg_path_ = ascend_info->GetInsertOpConfigPath();
44 input_format_ = ascend_info->GetInputFormat();
45 input_shape_map_ = ascend_info->GetInputShapeMap();
46 auto out_type = ascend_info->GetOutputType();
47 auto iter = kSupportedDtypeOptionMap.find(out_type);
48 if (out_type == DataType::kTypeUnknown) {
49 // do nothing
50 } else if (iter == kSupportedDtypeOptionMap.end()) {
51 MS_LOG(INFO) << "Unsupported output type " << out_type << ", use FP32 as default.";
52 } else {
53 output_type_ = iter->second;
54 }
55 dynamic_batch_size_ = ascend_info->GetDynamicBatchSize();
56 dynamic_image_size_ = ascend_info->GetDynamicImageSize();
57 precision_mode_ = TransforPrecisionToAcl(ascend_info->GetPrecisionMode());
58 op_select_impl_mode_ = ascend_info->GetOpSelectImplMode();
59 fusion_switch_cfg_path_ = ascend_info->GetFusionSwitchConfigPath();
60 device_id_ = ascend_info->GetDeviceID();
61 buffer_optimize_mode_ = ascend_info->GetBufferOptimizeMode();
62 if (!ascend_info->GetInputShape().empty()) {
63 input_shape_ = ascend_info->GetInputShape();
64 }
65 const char *soc_name = CALL_ASCEND_API(aclrtGetSocName);
66 if (soc_name == nullptr) {
67 MS_LOG(WARNING) << "Get soc version failed.";
68 return;
69 }
70 soc_version_ = soc_name;
71 }
72
GetSocName()73 std::string AclModelOptions::GetSocName() {
74 const char *soc_name = CALL_ASCEND_API(aclrtGetSocName);
75 if (soc_name == nullptr) {
76 MS_LOG(WARNING) << "Get soc version failed.";
77 return "";
78 }
79 return soc_name;
80 }
81
RenameInput(const std::vector<std::string> & input_names)82 void AclModelOptions::RenameInput(const std::vector<std::string> &input_names) {
83 if (input_names.size() != input_shape_map_.size()) {
84 MS_LOG(INFO) << "Inputs count not match";
85 return;
86 }
87 input_shape_ = "";
88 for (size_t i = 0; i < input_shape_map_.size(); i++) {
89 if (input_shape_map_.find(i) == input_shape_map_.end()) {
90 MS_LOG(WARNING) << "Not find the key: " << i;
91 return;
92 }
93 std::string s;
94 for (size_t j = 0; j < input_shape_map_[i].size(); j++) {
95 s += std::to_string(input_shape_map_[i][j]) + ",";
96 }
97 input_shape_ += input_names[i] + ":" + s.substr(0, s.size() - 1) + ";";
98 }
99 input_shape_ = input_shape_.substr(0, input_shape_.size() - 1);
100 MS_LOG(INFO) << "input name is " << input_shape_;
101 }
102
GenAclOptions() const103 std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions()
104 const {
105 const std::map<std::string const *, std::string> init_options_map = {
106 {&op_select_impl_mode_, ge::ir_option::OP_SELECT_IMPL_MODE},
107 {&soc_version_, ge::ir_option::SOC_VERSION},
108 {&fusion_switch_cfg_path_, ge::ir_option::FUSION_SWITCH_FILE},
109 {&buffer_optimize_mode_, ge::ir_option::BUFFER_OPTIMIZE}};
110
111 const std::map<std::string const *, std::string> build_options_map = {
112 {&insert_op_cfg_path_, ge::ir_option::INSERT_OP_FILE},
113 {&input_format_, ge::ir_option::INPUT_FORMAT},
114 {&input_shape_, ge::ir_option::INPUT_SHAPE},
115 {&output_type_, ge::ir_option::OUTPUT_TYPE},
116 {&precision_mode_, ge::ir_option::PRECISION_MODE},
117 {&dynamic_batch_size_, ge::ir_option::DYNAMIC_BATCH_SIZE},
118 {&dynamic_image_size_, ge::ir_option::DYNAMIC_IMAGE_SIZE}};
119
120 const std::set<std::string> first_graph_options = {
121 ge::ir_option::INSERT_OP_FILE,
122 ge::ir_option::INPUT_FORMAT,
123 ge::ir_option::INPUT_SHAPE,
124 };
125
126 const std::set<std::string> multi_graph_unsupported_options = {ge::ir_option::OUTPUT_TYPE};
127
128 std::map<std::string, std::string> init_options;
129 std::map<std::string, std::string> build_options;
130 for (auto [ms_option, acl_option_key] : init_options_map) {
131 if (ms_option == nullptr || ms_option->empty()) {
132 continue;
133 }
134 MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
135 (void)init_options.emplace(acl_option_key, *ms_option);
136 }
137
138 for (auto [ms_option, acl_option_key] : build_options_map) {
139 if (ms_option == nullptr || ms_option->empty()) {
140 continue;
141 }
142 MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
143 (void)build_options.emplace(acl_option_key, *ms_option);
144 }
145
146 // init by config file param
147 for (auto item : init_options_map_) {
148 MS_LOG(INFO) << "Option " << item.first << " : " << item.second;
149 if (item.first == ge::ir_option::SOC_VERSION) {
150 auto soc_version = item.second;
151 if (soc_version != soc_version_) {
152 MS_LOG(WARNING) << "ge.socVersion: " << soc_version
153 << " is different with this machine core type: " << soc_version_;
154 }
155 init_options[item.first] = item.second;
156 continue;
157 }
158 if (init_options.find(item.first) != init_options.end()) {
159 MS_LOG(WARNING) << "the parameters[" << item.first
160 << "] have been set through the API and do not need to be repeated.";
161 continue;
162 }
163 (void)init_options.emplace(item.first, item.second);
164 }
165
166 for (auto item : build_options_map_) {
167 MS_LOG(INFO) << "Option " << item.first << " : " << item.second;
168 if (build_options.find(item.first) != build_options.end()) {
169 MS_LOG(WARNING) << "the parameters[" << item.first
170 << "] have been set through the API and do not need to be repeated.";
171 continue;
172 }
173 (void)build_options.emplace(item.first, item.second);
174 }
175
176 // first_graph_flag has value means being multi graph mode
177 if (first_graph_flag_.has_value()) {
178 for (const auto &option : multi_graph_unsupported_options) {
179 (void)build_options.erase(option);
180 }
181 // non-input graph
182 if (!first_graph_flag_) {
183 for (const auto &option : first_graph_options) {
184 (void)build_options.erase(option);
185 }
186 }
187 }
188
189 return {init_options, build_options};
190 }
191
GenAoeOptions(std::vector<std::string> * aoe_modes)192 std::string AclModelOptions::GenAoeOptions(std::vector<std::string> *aoe_modes) {
193 std::string res;
194 std::map<std::string, std::string> aoe_options = aoe_global_options_map_;
195 aoe_options.insert(aoe_tuning_options_map_.begin(), aoe_tuning_options_map_.end());
196 if (aoe_options.find("job_type") != aoe_options.end()) {
197 aoe_modes->clear();
198 (void)aoe_modes->emplace_back(aoe_options.at("job_type"));
199 }
200 if (aoe_modes->empty()) {
201 MS_LOG(WARNING) << "Aoe mode are invalid "
202 << "; It should be 'subgraph tuning, operator tuning' in aoe_mode, or '1, 2' in job_type";
203 }
204
205 for (auto item : aoe_options) {
206 if (item.first == "job_type" || item.first == "framework" || item.first == "model") {
207 continue;
208 }
209 if (item.second.empty()) {
210 res += " --" + item.first;
211 } else {
212 res += " --" + item.first + "=" + item.second;
213 }
214 }
215
216 MS_LOG(INFO) << "aoe_options: " << res;
217 return res;
218 }
219
GenAclOptionsKey() const220 std::string AclModelOptions::GenAclOptionsKey() const {
221 auto [init_options, build_options] = GenAclOptions();
222 std::string key_str;
223 for (auto &[key, value] : init_options) {
224 key_str += key + "^" + value + "^^";
225 }
226 for (auto &[key, value] : build_options) {
227 key_str += key + "^" + value + "^^";
228 }
229 return key_str;
230 }
231 } // namespace mindspore
232