• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "cxx_api/model/acl/acl_model_options.h"
17 #include <set>
18 #include <memory>
19 #include "utils/log_adapter.h"
20 #include "external/ge/ge_api_types.h"
21 #include "acl/acl_base.h"
22 
23 namespace mindspore {
24 static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"},
25                                                                               {DataType::kNumberTypeFloat32, "FP32"},
26                                                                               {DataType::kNumberTypeUInt8, "UINT8"}};
27 
AclModelOptions(const std::shared_ptr<Context> & context)28 AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
29   if (context == nullptr) {
30     return;
31   }
32   auto &device_infos = context->MutableDeviceInfo();
33   if (device_infos.size() != 1) {
34     return;
35   }
36   auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
37   if (ascend310_info == nullptr) {
38     return;
39   }
40 
41   insert_op_cfg_path_ = ascend310_info->GetInsertOpConfigPath();
42   input_format_ = ascend310_info->GetInputFormat();
43   input_shape_map_ = ascend310_info->GetInputShapeMap();
44   auto out_type = ascend310_info->GetOutputType();
45   auto iter = kSupportedDtypeOptionMap.find(out_type);
46   if (out_type == DataType::kTypeUnknown) {
47     // do nothing
48   } else if (iter == kSupportedDtypeOptionMap.end()) {
49     MS_LOG(WARNING) << "Unsupported output type " << out_type << ", use FP32 as default.";
50   } else {
51     output_type_ = iter->second;
52   }
53   dynamic_batch_size_ = ascend310_info->GetDynamicBatchSize();
54   precision_mode_ = ascend310_info->GetPrecisionMode();
55   op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
56   fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
57   device_id_ = ascend310_info->GetDeviceID();
58   buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode();
59   const char *soc_name = aclrtGetSocName();
60   if (soc_name == nullptr) {
61     MS_LOG(WARNING) << "Get soc version failed.";
62     return;
63   }
64   soc_version_ = soc_name;
65 }
66 
RenameInput(const std::vector<std::string> & input_names)67 void AclModelOptions::RenameInput(const std::vector<std::string> &input_names) {
68   if (input_names.size() != input_shape_map_.size()) {
69     MS_LOG(INFO) << "Inputs count not match";
70     return;
71   }
72   input_shape_ = "";
73   for (size_t i = 0; i < input_shape_map_.size(); i++) {
74     std::string s;
75     for (size_t j = 0; j < input_shape_map_[i].size(); j++) {
76       s += std::to_string(input_shape_map_[i][j]) + ",";
77     }
78     input_shape_ += input_names[i] + ":" + s.substr(0, s.size() - 1) + ";";
79   }
80   input_shape_ = input_shape_.substr(0, input_shape_.size() - 1);
81   MS_LOG(INFO) << "input name is " << input_shape_;
82 }
83 
GenAclOptions() const84 std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions()
85   const {
86   const std::map<std::string const *, std::string> init_options_map = {
87     {&op_select_impl_mode_, ge::ir_option::OP_SELECT_IMPL_MODE},
88     {&soc_version_, ge::ir_option::SOC_VERSION},
89     {&fusion_switch_cfg_path_, ge::ir_option::FUSION_SWITCH_FILE},
90     {&buffer_optimize_mode_, ge::ir_option::BUFFER_OPTIMIZE}};
91 
92   const std::map<std::string const *, std::string> build_options_map = {
93     {&insert_op_cfg_path_, ge::ir_option::INSERT_OP_FILE},
94     {&input_format_, ge::ir_option::INPUT_FORMAT},
95     {&input_shape_, ge::ir_option::INPUT_SHAPE},
96     {&output_type_, ge::ir_option::OUTPUT_TYPE},
97     {&precision_mode_, ge::ir_option::PRECISION_MODE},
98     {&dynamic_batch_size_, ge::ir_option::DYNAMIC_BATCH_SIZE},
99     {&dynamic_image_size_, ge::ir_option::DYNAMIC_IMAGE_SIZE}};
100 
101   const std::set<std::string> first_graph_options = {
102     ge::ir_option::INSERT_OP_FILE,
103     ge::ir_option::INPUT_FORMAT,
104     ge::ir_option::INPUT_SHAPE,
105   };
106 
107   const std::set<std::string> multi_graph_unsupported_options = {ge::ir_option::OUTPUT_TYPE};
108 
109   std::map<std::string, std::string> init_options;
110   std::map<std::string, std::string> build_options;
111   for (auto [ms_option, acl_option_key] : init_options_map) {
112     if (ms_option == nullptr || ms_option->empty()) {
113       continue;
114     }
115     MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
116     init_options.emplace(acl_option_key, *ms_option);
117   }
118 
119   for (auto [ms_option, acl_option_key] : build_options_map) {
120     if (ms_option == nullptr || ms_option->empty()) {
121       continue;
122     }
123     MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
124     build_options.emplace(acl_option_key, *ms_option);
125   }
126 
127   // first_graph_flag has value means being multi graph mode
128   if (first_graph_flag_.has_value()) {
129     for (const auto &option : multi_graph_unsupported_options) {
130       build_options.erase(option);
131     }
132     // non-input graph
133     if (!first_graph_flag_) {
134       for (const auto &option : first_graph_options) {
135         build_options.erase(option);
136       }
137     }
138   }
139 
140   return {init_options, build_options};
141 }
142 
GenAclOptionsKey() const143 std::string AclModelOptions::GenAclOptionsKey() const {
144   auto [init_options, build_options] = GenAclOptions();
145   std::string key_str;
146   for (auto &[key, value] : init_options) {
147     key_str += key + "^" + value + "^^";
148   }
149   for (auto &[key, value] : build_options) {
150     key_str += key + "^" + value + "^^";
151   }
152   return key_str;
153 }
154 }  // namespace mindspore
155