• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/acl_ir/ge_adapter_info.h"
18 #include <algorithm>
19 #include <limits>
20 #include "include/transform/graph_ir/utils.h"
21 #include "transform/graph_ir/transform_util.h"
22 #include "graph/operator_factory.h"
23 
24 namespace mindspore {
25 namespace transform {
InitOpType()26 void GeAdapterInfo::InitOpType() { info_.op_type = adapter_->getOpType(); }
27 
InitAclInputsAndOutputs()28 void GeAdapterInfo::InitAclInputsAndOutputs() {
29   InitParametersMap(adapter_->getInputMap(), adapter_->getDynInputMap(), true);
30   InitParametersMap(adapter_->getOutputMap(), adapter_->getDynOutputMap(), false);
31 }
32 
InitRefMap()33 void GeAdapterInfo::InitRefMap() {
34   for (const auto &[output_index, output_param_info] : info_.output_idx_ms2ge) {
35     for (const auto &[input_index, input_param_info] : info_.input_idx_ms2ge) {
36       if (output_param_info.name == input_param_info.name) {
37         (void)info_.ref_map_.emplace(IntToSize(output_index), IntToSize(input_index));
38         break;
39       }
40     }
41   }
42 }
43 
44 template <typename ParamMap, typename DynParamMap>
InitParametersMap(const ParamMap & params,const DynParamMap & dyn_params,bool is_input)45 void GeAdapterInfo::InitParametersMap(const ParamMap &params, const DynParamMap &dyn_params, bool is_input) {
46   auto &mapping_flags = is_input ? info_.input_mapping_flags : info_.output_mapping_flags;
47   auto &idx_ms2ge = is_input ? info_.input_idx_ms2ge : info_.output_idx_ms2ge;
48   auto &idx_ge2ms = is_input ? info_.input_idx_ge2ms : info_.output_idx_ge2ms;
49 
50   if (params.empty() && dyn_params.empty()) {
51     mapping_flags |= GeTensorInfo::kEmptyParam;
52     return;
53   }
54 
55   // calculate index of dynamic input/output
56   size_t ge_dynmaic_idx = std::numeric_limits<size_t>::max();
57   if (!dyn_params.empty()) {
58     if (dyn_params.size() > 1) {
59       MS_LOG(DEBUG) << "Op " << adapter_->getOpType() << " has " << dyn_params.size() << " dynamic "
60                     << (is_input ? "inputs" : "outputs");
61       mapping_flags |= GeTensorInfo::kMultiDynParam;
62     } else {
63       mapping_flags |= GeTensorInfo::kDynamicParam;
64       ge_dynmaic_idx = dyn_params.cbegin()->second.index;
65     }
66   }
67 
68   auto get_ms_idx = [is_input](int index) {
69     // for anf cnode, the 1st input is primitive name, so for input the real input index is `index - 1`
70     return is_input ? index - 1 : index;
71   };
72 
73   // process required/optional inputs or required outputs
74   for (const auto &[k, v] : params) {
75     int ms_idx = get_ms_idx(k);
76     uint32_t ge_idx = static_cast<uint32_t>(v.index);
77     // MindSpore Index --> GE Info
78     if constexpr (std::is_same<std::remove_cv_t<decltype(v)>, InputDesc>::value) {
79       idx_ms2ge[ms_idx] = Ms2GeParamInfo{
80         ge_idx, v.name, v.type == InputDesc::OPTIONAL ? Ms2GeParamInfo::OPTIONAL : Ms2GeParamInfo::REQUIRED,
81         ge_idx > ge_dynmaic_idx};
82     } else {
83       idx_ms2ge[ms_idx] = Ms2GeParamInfo{ge_idx, v.name, Ms2GeParamInfo::REQUIRED, ge_idx > ge_dynmaic_idx};
84     }
85 
86     // input/output: GE(GraphEngine) Index --> MindSpore Index
87     idx_ge2ms[ge_idx] = ms_idx;
88     if (is_input) {
89       max_input_ms_proto_idx_ = std::max(max_input_ms_proto_idx_, ms_idx);
90     }
91   }
92 
93   // process dynamic inputs/outputs
94   for (const auto &[k, v] : dyn_params) {
95     int ms_idx = get_ms_idx(k);
96     uint32_t ge_idx = static_cast<uint32_t>(v.index);
97     // MindSpore Index --> GE Info
98     idx_ms2ge[ms_idx] = Ms2GeParamInfo{ge_idx, v.name, Ms2GeParamInfo::DYNAMIC, ge_idx > ge_dynmaic_idx};
99     // input/output: GE(GraphEngine) Index --> MindSpore Index
100     idx_ge2ms[ge_idx] = ms_idx;
101     if (is_input) {
102       max_input_ms_proto_idx_ = std::max(max_input_ms_proto_idx_, ms_idx);
103     }
104   }
105 }
106 
InitInputSupportedDataType()107 void GeAdapterInfo::InitInputSupportedDataType() {
108   info_.input_supported_dtypes.clear();
109   for (const auto &[k, v] : adapter_->getInputMap()) {
110     (void)info_.input_supported_dtypes.emplace(k - 1, v.supported_dtypes);
111   }
112   for (const auto &[k, v] : adapter_->getDynInputMap()) {
113     (void)info_.input_supported_dtypes.emplace(k - 1, v.supported_dtypes);
114   }
115 }
116 
InitOutputSupportedDataType()117 void GeAdapterInfo::InitOutputSupportedDataType() {
118   info_.output_supported_dtypes.clear();
119   for (const auto &[k, v] : adapter_->getOutputMap()) {
120     (void)info_.output_supported_dtypes.emplace(k, v.supported_dtypes);
121   }
122   for (const auto &[k, v] : adapter_->getDynOutputMap()) {
123     (void)info_.output_supported_dtypes.emplace(k, v.supported_dtypes);
124   }
125 }
126 
GetGeAttrValueByMsAttrValue(const std::string & attr_name,ValuePtr * ms_value)127 void GeAdapterInfo::GetGeAttrValueByMsAttrValue(const std::string &attr_name, ValuePtr *ms_value) {
128   MS_EXCEPTION_IF_NULL(ms_value);
129   // class Value is a abstract class
130   auto iter = get_attr_cache_.find({attr_name, *ms_value});
131   if (iter != get_attr_cache_.end()) {
132     *ms_value = iter->second;
133     return;
134   }
135 
136   int ret = 0;
137   auto old_value = *ms_value;
138   ret = adapter_->getAttr(attr_name, ms_value);
139   if (ret != 0) {
140     MS_LOG(EXCEPTION) << "failed to get attr:" << attr_name << " for primitive " << info_.op_type;
141   }
142   get_attr_cache_[{attr_name, old_value}] = *ms_value;
143 }
144 
GetGeAttrValueByMsInputValue(const uint32_t & input_idx,ValuePtr * ms_value)145 void GeAdapterInfo::GetGeAttrValueByMsInputValue(const uint32_t &input_idx, ValuePtr *ms_value) {
146   MS_EXCEPTION_IF_NULL(ms_value);
147   // class Value is a abstract class
148   auto iter = get_input_attr_cache_.find({input_idx, *ms_value});
149   if (iter != get_input_attr_cache_.end()) {
150     *ms_value = iter->second;
151     return;
152   }
153 
154   int ret = 0;
155   auto old_value = *ms_value;
156   ret = adapter_->getAttr(input_idx, ms_value);
157   if (ret != 0) {
158     MS_LOG(EXCEPTION) << "failed to get attr from input[" << input_idx << "] for primitive " << info_.op_type;
159   }
160   get_input_attr_cache_[{input_idx, old_value}] = *ms_value;
161 }
162 
InitAttrMap()163 void GeAdapterInfo::InitAttrMap() {
164   info_.attr_map.clear();
165   for (const auto &[k, v] : adapter_->getAttrMap()) {
166     (void)info_.attr_map.emplace(k, v.name);
167   }
168 }
169 
InitInputToAttrMap()170 void GeAdapterInfo::InitInputToAttrMap() {
171   info_.input_attr_map.clear();
172   for (const auto &[k, v] : adapter_->getInputAttrMap()) {
173     (void)info_.input_attr_map.emplace(k - 1, v.name);
174   }
175 }
176 
InitAttrToInputMap()177 void GeAdapterInfo::InitAttrToInputMap() {
178   auto attr_input_map = adapter_->getAttrInputMap();
179   auto input_map = adapter_->getInputMap();
180   for (const auto &[ms_attr_name, ge_input_name] : attr_input_map) {
181     const auto &ge_input_name_for_cpp17 = ge_input_name;
182     auto iter = std::find_if(input_map.begin(), input_map.end(), [&ge_input_name_for_cpp17](const auto &desc) {
183       return desc.second.name == ge_input_name_for_cpp17;
184     });
185     if (iter == input_map.end()) {
186       MS_LOG(EXCEPTION) << "Error adapter register of" << ms_attr_name << " and " << ge_input_name
187                         << ", type: " << adapter_->getOpType();
188     }
189     (void)info_.attr_input_map.emplace(IntToSize(iter->first - 1), ms_attr_name);
190   }
191 }
192 
InitInfo()193 void GeAdapterInfo::InitInfo() {
194   InitOpType();
195 
196   InitInputSupportedDataType();
197   InitOutputSupportedDataType();
198 
199   InitAttrMap();
200   InitInputToAttrMap();
201   InitAttrToInputMap();
202 
203   InitAclInputsAndOutputs();
204   InitRefMap();
205   MS_LOG(DEBUG) << "INIT INFO:" << info_.op_type << " -- " << info_.input_supported_dtypes[0] << " --- "
206                 << info_.output_supported_dtypes[0];
207 }
208 
GetInstance()209 GeAdapterManager &GeAdapterManager::GetInstance() {
210   static GeAdapterManager instance;
211   return instance;
212 }
213 
GetInfo(const std::string & prim_name,bool is_training=true)214 GeAdapterInfoPtr GeAdapterManager::GetInfo(const std::string &prim_name, bool is_training = true) {
215   std::lock_guard<std::mutex> guard(lock_);
216   auto iter = op_cache_.find(prim_name);
217   if (iter != op_cache_.end()) {
218     return iter->second;
219   }
220 
221   OpAdapterPtr adpt = FindAdapter(prim_name, is_training);
222   if (adpt == nullptr) {
223     MS_LOG(DEBUG) << "The current name '" << prim_name << "' needs to add adapter.";
224     return nullptr;
225   }
226   if (prim_name != adpt->getOpType()) {
227     MS_LOG(DEBUG) << "Note: primitive name is difference with adapter: prim name: " << prim_name
228                   << ", ge name: " << adpt->getOpType();
229   }
230   auto info_ptr = std::make_shared<GeAdapterInfo>(adpt);
231   info_ptr->InitInfo();
232   op_cache_[prim_name] = info_ptr;
233   return info_ptr;
234 }
235 }  // namespace transform
236 }  // namespace mindspore
237