1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/acl_ir/ge_adapter_info.h"
18 #include <algorithm>
19 #include <limits>
20 #include "include/transform/graph_ir/utils.h"
21 #include "transform/graph_ir/transform_util.h"
22 #include "graph/operator_factory.h"
23
24 namespace mindspore {
25 namespace transform {
InitOpType()26 void GeAdapterInfo::InitOpType() { info_.op_type = adapter_->getOpType(); }
27
InitAclInputsAndOutputs()28 void GeAdapterInfo::InitAclInputsAndOutputs() {
29 InitParametersMap(adapter_->getInputMap(), adapter_->getDynInputMap(), true);
30 InitParametersMap(adapter_->getOutputMap(), adapter_->getDynOutputMap(), false);
31 }
32
InitRefMap()33 void GeAdapterInfo::InitRefMap() {
34 for (const auto &[output_index, output_param_info] : info_.output_idx_ms2ge) {
35 for (const auto &[input_index, input_param_info] : info_.input_idx_ms2ge) {
36 if (output_param_info.name == input_param_info.name) {
37 (void)info_.ref_map_.emplace(IntToSize(output_index), IntToSize(input_index));
38 break;
39 }
40 }
41 }
42 }
43
44 template <typename ParamMap, typename DynParamMap>
InitParametersMap(const ParamMap & params,const DynParamMap & dyn_params,bool is_input)45 void GeAdapterInfo::InitParametersMap(const ParamMap ¶ms, const DynParamMap &dyn_params, bool is_input) {
46 auto &mapping_flags = is_input ? info_.input_mapping_flags : info_.output_mapping_flags;
47 auto &idx_ms2ge = is_input ? info_.input_idx_ms2ge : info_.output_idx_ms2ge;
48 auto &idx_ge2ms = is_input ? info_.input_idx_ge2ms : info_.output_idx_ge2ms;
49
50 if (params.empty() && dyn_params.empty()) {
51 mapping_flags |= GeTensorInfo::kEmptyParam;
52 return;
53 }
54
55 // calculate index of dynamic input/output
56 size_t ge_dynmaic_idx = std::numeric_limits<size_t>::max();
57 if (!dyn_params.empty()) {
58 if (dyn_params.size() > 1) {
59 MS_LOG(DEBUG) << "Op " << adapter_->getOpType() << " has " << dyn_params.size() << " dynamic "
60 << (is_input ? "inputs" : "outputs");
61 mapping_flags |= GeTensorInfo::kMultiDynParam;
62 } else {
63 mapping_flags |= GeTensorInfo::kDynamicParam;
64 ge_dynmaic_idx = dyn_params.cbegin()->second.index;
65 }
66 }
67
68 auto get_ms_idx = [is_input](int index) {
69 // for anf cnode, the 1st input is primitive name, so for input the real input index is `index - 1`
70 return is_input ? index - 1 : index;
71 };
72
73 // process required/optional inputs or required outputs
74 for (const auto &[k, v] : params) {
75 int ms_idx = get_ms_idx(k);
76 uint32_t ge_idx = static_cast<uint32_t>(v.index);
77 // MindSpore Index --> GE Info
78 if constexpr (std::is_same<std::remove_cv_t<decltype(v)>, InputDesc>::value) {
79 idx_ms2ge[ms_idx] = Ms2GeParamInfo{
80 ge_idx, v.name, v.type == InputDesc::OPTIONAL ? Ms2GeParamInfo::OPTIONAL : Ms2GeParamInfo::REQUIRED,
81 ge_idx > ge_dynmaic_idx};
82 } else {
83 idx_ms2ge[ms_idx] = Ms2GeParamInfo{ge_idx, v.name, Ms2GeParamInfo::REQUIRED, ge_idx > ge_dynmaic_idx};
84 }
85
86 // input/output: GE(GraphEngine) Index --> MindSpore Index
87 idx_ge2ms[ge_idx] = ms_idx;
88 if (is_input) {
89 max_input_ms_proto_idx_ = std::max(max_input_ms_proto_idx_, ms_idx);
90 }
91 }
92
93 // process dynamic inputs/outputs
94 for (const auto &[k, v] : dyn_params) {
95 int ms_idx = get_ms_idx(k);
96 uint32_t ge_idx = static_cast<uint32_t>(v.index);
97 // MindSpore Index --> GE Info
98 idx_ms2ge[ms_idx] = Ms2GeParamInfo{ge_idx, v.name, Ms2GeParamInfo::DYNAMIC, ge_idx > ge_dynmaic_idx};
99 // input/output: GE(GraphEngine) Index --> MindSpore Index
100 idx_ge2ms[ge_idx] = ms_idx;
101 if (is_input) {
102 max_input_ms_proto_idx_ = std::max(max_input_ms_proto_idx_, ms_idx);
103 }
104 }
105 }
106
InitInputSupportedDataType()107 void GeAdapterInfo::InitInputSupportedDataType() {
108 info_.input_supported_dtypes.clear();
109 for (const auto &[k, v] : adapter_->getInputMap()) {
110 (void)info_.input_supported_dtypes.emplace(k - 1, v.supported_dtypes);
111 }
112 for (const auto &[k, v] : adapter_->getDynInputMap()) {
113 (void)info_.input_supported_dtypes.emplace(k - 1, v.supported_dtypes);
114 }
115 }
116
InitOutputSupportedDataType()117 void GeAdapterInfo::InitOutputSupportedDataType() {
118 info_.output_supported_dtypes.clear();
119 for (const auto &[k, v] : adapter_->getOutputMap()) {
120 (void)info_.output_supported_dtypes.emplace(k, v.supported_dtypes);
121 }
122 for (const auto &[k, v] : adapter_->getDynOutputMap()) {
123 (void)info_.output_supported_dtypes.emplace(k, v.supported_dtypes);
124 }
125 }
126
GetGeAttrValueByMsAttrValue(const std::string & attr_name,ValuePtr * ms_value)127 void GeAdapterInfo::GetGeAttrValueByMsAttrValue(const std::string &attr_name, ValuePtr *ms_value) {
128 MS_EXCEPTION_IF_NULL(ms_value);
129 // class Value is a abstract class
130 auto iter = get_attr_cache_.find({attr_name, *ms_value});
131 if (iter != get_attr_cache_.end()) {
132 *ms_value = iter->second;
133 return;
134 }
135
136 int ret = 0;
137 auto old_value = *ms_value;
138 ret = adapter_->getAttr(attr_name, ms_value);
139 if (ret != 0) {
140 MS_LOG(EXCEPTION) << "failed to get attr:" << attr_name << " for primitive " << info_.op_type;
141 }
142 get_attr_cache_[{attr_name, old_value}] = *ms_value;
143 }
144
GetGeAttrValueByMsInputValue(const uint32_t & input_idx,ValuePtr * ms_value)145 void GeAdapterInfo::GetGeAttrValueByMsInputValue(const uint32_t &input_idx, ValuePtr *ms_value) {
146 MS_EXCEPTION_IF_NULL(ms_value);
147 // class Value is a abstract class
148 auto iter = get_input_attr_cache_.find({input_idx, *ms_value});
149 if (iter != get_input_attr_cache_.end()) {
150 *ms_value = iter->second;
151 return;
152 }
153
154 int ret = 0;
155 auto old_value = *ms_value;
156 ret = adapter_->getAttr(input_idx, ms_value);
157 if (ret != 0) {
158 MS_LOG(EXCEPTION) << "failed to get attr from input[" << input_idx << "] for primitive " << info_.op_type;
159 }
160 get_input_attr_cache_[{input_idx, old_value}] = *ms_value;
161 }
162
InitAttrMap()163 void GeAdapterInfo::InitAttrMap() {
164 info_.attr_map.clear();
165 for (const auto &[k, v] : adapter_->getAttrMap()) {
166 (void)info_.attr_map.emplace(k, v.name);
167 }
168 }
169
InitInputToAttrMap()170 void GeAdapterInfo::InitInputToAttrMap() {
171 info_.input_attr_map.clear();
172 for (const auto &[k, v] : adapter_->getInputAttrMap()) {
173 (void)info_.input_attr_map.emplace(k - 1, v.name);
174 }
175 }
176
InitAttrToInputMap()177 void GeAdapterInfo::InitAttrToInputMap() {
178 auto attr_input_map = adapter_->getAttrInputMap();
179 auto input_map = adapter_->getInputMap();
180 for (const auto &[ms_attr_name, ge_input_name] : attr_input_map) {
181 const auto &ge_input_name_for_cpp17 = ge_input_name;
182 auto iter = std::find_if(input_map.begin(), input_map.end(), [&ge_input_name_for_cpp17](const auto &desc) {
183 return desc.second.name == ge_input_name_for_cpp17;
184 });
185 if (iter == input_map.end()) {
186 MS_LOG(EXCEPTION) << "Error adapter register of" << ms_attr_name << " and " << ge_input_name
187 << ", type: " << adapter_->getOpType();
188 }
189 (void)info_.attr_input_map.emplace(IntToSize(iter->first - 1), ms_attr_name);
190 }
191 }
192
InitInfo()193 void GeAdapterInfo::InitInfo() {
194 InitOpType();
195
196 InitInputSupportedDataType();
197 InitOutputSupportedDataType();
198
199 InitAttrMap();
200 InitInputToAttrMap();
201 InitAttrToInputMap();
202
203 InitAclInputsAndOutputs();
204 InitRefMap();
205 MS_LOG(DEBUG) << "INIT INFO:" << info_.op_type << " -- " << info_.input_supported_dtypes[0] << " --- "
206 << info_.output_supported_dtypes[0];
207 }
208
GetInstance()209 GeAdapterManager &GeAdapterManager::GetInstance() {
210 static GeAdapterManager instance;
211 return instance;
212 }
213
GetInfo(const std::string & prim_name,bool is_training=true)214 GeAdapterInfoPtr GeAdapterManager::GetInfo(const std::string &prim_name, bool is_training = true) {
215 std::lock_guard<std::mutex> guard(lock_);
216 auto iter = op_cache_.find(prim_name);
217 if (iter != op_cache_.end()) {
218 return iter->second;
219 }
220
221 OpAdapterPtr adpt = FindAdapter(prim_name, is_training);
222 if (adpt == nullptr) {
223 MS_LOG(DEBUG) << "The current name '" << prim_name << "' needs to add adapter.";
224 return nullptr;
225 }
226 if (prim_name != adpt->getOpType()) {
227 MS_LOG(DEBUG) << "Note: primitive name is difference with adapter: prim name: " << prim_name
228 << ", ge name: " << adpt->getOpType();
229 }
230 auto info_ptr = std::make_shared<GeAdapterInfo>(adpt);
231 info_ptr->InitInfo();
232 op_cache_[prim_name] = info_ptr;
233 return info_ptr;
234 }
235 } // namespace transform
236 } // namespace mindspore
237