1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/backend/optimizer/op_adaptation_info_factory.h"
17
18 #include "kernel/oplib/oplib.h"
19 #include "utils/log_adapter.h"
20 #include "include/common/utils/anfalgo.h"
21 #include "include/common/utils/convert_utils.h"
22 #include "include/backend/optimizer/helper.h"
23 #include "ops/framework_ops.h"
24
25 namespace mindspore::opt {
set_backend_op_name(const std::string & default_op_name)26 OpAdaptationInfo &OpAdaptationInfo::set_backend_op_name(const std::string &default_op_name) {
27 backend_op_name_ = default_op_name;
28 return *this;
29 }
30
set_target_op_name(const std::string & target_op_name)31 OpAdaptationInfo &OpAdaptationInfo::set_target_op_name(const std::string &target_op_name) {
32 target_op_name_ = target_op_name;
33 return *this;
34 }
35
set_pre_check_func(std::function<bool (CNodePtr)> pre_check_func)36 OpAdaptationInfo &OpAdaptationInfo::set_pre_check_func(std::function<bool(CNodePtr)> pre_check_func) {
37 pre_check_func_ = std::move(pre_check_func);
38 return *this;
39 }
40
set_need_tbe_check_supported(bool need_tbe_check_supported)41 OpAdaptationInfo &OpAdaptationInfo::set_need_tbe_check_supported(bool need_tbe_check_supported) {
42 need_tbe_check_supported_ = need_tbe_check_supported;
43 return *this;
44 }
45
set_input_attr_info(size_t input_index,const std::string & attr_data_type)46 OpAdaptationInfo &OpAdaptationInfo::set_input_attr_info(size_t input_index, const std::string &attr_data_type) {
47 auto find = input_attr_map_.find(input_index);
48 if (find != input_attr_map_.end()) {
49 MS_LOG(ERROR) << "This input index (" << input_index << ")"
50 << " has been registered.";
51 return *this;
52 }
53 input_attr_map_[input_index] = attr_data_type;
54 return *this;
55 }
56
set_is_ascend_mindir()57 OpAdaptationInfo &OpAdaptationInfo::set_is_ascend_mindir() {
58 is_ascend_mindir_ = true;
59 return *this;
60 }
61
GetInstance()62 OpAdaptationInfoRegister &OpAdaptationInfoRegister::GetInstance() {
63 static OpAdaptationInfoRegister inst;
64 return inst;
65 }
66
GenerateKey(const std::string & me_op_name,const std::string & device_name,bool flag)67 std::string OpAdaptationInfoRegister::GenerateKey(const std::string &me_op_name, const std::string &device_name,
68 bool flag) {
69 if (device_name != kCPUDevice && device_name != kGPUDevice && device_name != kAscendDevice) {
70 MS_LOG(EXCEPTION) << "Backend type is invalid, should be one of [" << kCPUDevice << ", " << kGPUDevice << ", "
71 << kAscendDevice << "], but got " << device_name;
72 }
73
74 std::string flag_str = flag ? "true" : "false";
75 return std::string(me_op_name + device_name + flag_str);
76 }
77
GetOpName()78 std::set<std::string> &OpAdaptationInfoRegister::GetOpName() {
79 static std::set<std::string> op_names;
80 return op_names;
81 }
82
GetOpInfoMap()83 std::map<std::string, OpAdaptationInfo *> &OpAdaptationInfoRegister::GetOpInfoMap() {
84 static std::map<std::string, OpAdaptationInfo *> op_info_map;
85 return op_info_map;
86 }
87
RegOpAdaptationInfo(OpAdaptationInfo * reg_info)88 void OpAdaptationInfoRegister::RegOpAdaptationInfo(OpAdaptationInfo *reg_info) {
89 MS_EXCEPTION_IF_NULL(reg_info);
90 (void)GetOpName().insert(reg_info->me_op_name());
91 auto key = GenerateKey(reg_info->me_op_name(), reg_info->device_name(), reg_info->flag());
92 auto find = GetOpInfoMap().find(key);
93 if (find != GetOpInfoMap().end()) {
94 MS_LOG(DEBUG) << "This key (" << key << ") has been registered in me op info map.";
95 return;
96 }
97 MS_LOG(DEBUG) << "Reg op adaptation info to factory, key: " << key;
98 GetOpInfoMap()[key] = reg_info;
99 }
100
GetOpAdaptationInfo(const std::string & me_op_name,const std::string & device_name,bool flag)101 OpAdaptationInfo *OpAdaptationInfoRegister::GetOpAdaptationInfo(const std::string &me_op_name,
102 const std::string &device_name, bool flag) {
103 auto name_iter = GetOpName().find(me_op_name);
104 if (name_iter == GetOpName().end()) {
105 return nullptr;
106 }
107 auto key = GenerateKey(me_op_name, device_name, flag);
108 auto iter = GetOpInfoMap().find(key);
109 if (iter == GetOpInfoMap().end()) {
110 MS_LOG(DEBUG) << "Can't find op adaptation for op " << me_op_name << " on " << device_name << " when flag is "
111 << flag;
112 return nullptr;
113 }
114 return iter->second;
115 }
116
/**
 * Create a new CNode named op_adaptation_info.target_op_name() from origin_op,
 * converting every input registered in input_attr_map() into a primitive
 * attribute and dropping it from the input list.
 *
 * Returns nullptr (caller keeps the origin op) when a registered input is not
 * a constant value node, carries a monad abstract, or fails the conversion.
 *
 * Fix: origin_op->func_graph() was dereferenced inside the loop
 * (graph->cast<KernelGraphPtr>()) but only null-checked after the loop; the
 * check is now hoisted above the loop so a null graph fails fast instead of
 * crashing.
 */
CNodePtr OpAdaptationInfoRegister::CreateTargetOp(const CNodePtr &origin_op,
                                                  const OpAdaptationInfo &op_adaptation_info) {
  MS_EXCEPTION_IF_NULL(origin_op);
  auto target_op_name = op_adaptation_info.target_op_name();
  auto input_attr_info_map = op_adaptation_info.input_attr_map();

  auto origin_primitive = GetCNodePrimitive(origin_op);
  MS_EXCEPTION_IF_NULL(origin_primitive);
  // The target primitive starts from a copy of all origin attrs; converted
  // inputs are appended to it by ConvertInputToAttr below.
  auto target_primitive = std::make_shared<Primitive>(target_op_name);
  MS_EXCEPTION_IF_NULL(target_primitive);
  (void)target_primitive->SetAttrs(origin_primitive->attrs());
  std::vector<AnfNodePtr> target_inputs;
  auto inputs = origin_op->inputs();
  target_inputs.reserve(inputs.size());
  target_inputs.push_back(inputs[0]);
  auto graph = origin_op->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  bool ir_change = false;
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    // Look through Depend edges to reach the real producer of the input.
    if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
      input_node = AnfUtils::VisitKernel(input_node, 0).first;
    }

    auto iter = input_attr_info_map.find(i);
    if (iter != input_attr_info_map.end()) {
      // Only constant (value node), non-monad inputs can become attributes.
      auto is_value_node = input_node->isa<ValueNode>();
      auto is_monad = HasAbstractMonad(input_node);
      if (!is_value_node || is_monad) {
        MS_LOG(INFO) << "Convert " << origin_op->fullname_with_scope() << "'s input " << i
                     << " to attr failed. input is value node: " << is_value_node << ", is monad: " << is_monad;
        return nullptr;
      }

      auto ret = ConvertInputToAttr(origin_op, i, input_node, iter->second, target_primitive);
      if (!ret) {
        MS_LOG(INFO) << "Convert " << origin_op->fullname_with_scope() << "'s input " << i << " to attr failed.";
        return nullptr;
      }
      // The value node is now baked into the primitive attrs; drop it from the
      // kernel graph's value-node bookkeeping.
      auto kernel_graph = graph->cast<KernelGraphPtr>();
      MS_EXCEPTION_IF_NULL(kernel_graph);
      (void)kernel_graph->RemoveValueNodeFromGraph(input_node->cast<ValueNodePtr>());
      ir_change = true;
    } else {
      target_inputs.push_back(inputs[i + 1]);
    }
  }

  // Update target_op's inputs
  target_inputs[0] = NewValueNode(target_primitive);
  auto target_op = opt::NewCNode(target_inputs, graph, {origin_op});
  MS_EXCEPTION_IF_NULL(target_op);
  // Carry over everything callers and later passes rely on: abstract, scope,
  // primal attrs/debug infos and node attrs.
  target_op->set_abstract(origin_op->abstract());
  target_op->set_scope(origin_op->scope());
  target_op->set_primal_attrs(origin_op->primal_attrs());
  target_op->set_attrs(origin_op->attrs());
  target_op->set_primal_debug_infos(origin_op->primal_debug_infos());
  common::AnfAlgo::EraseNodeAttr(kAttrIsKernelDynamicImpl, target_op);
  if (common::AnfAlgo::HasNodeAttr(kAttrCustAicpu, origin_op)) {
    common::AnfAlgo::CopyNodeAttr(kAttrCustAicpu, origin_op, target_op);
  }

  common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), target_op);
  common::AnfAlgo::SetNodeAttr(kAttrMeOpName, MakeValue(op_adaptation_info.me_op_name()), target_op);
  common::AnfAlgo::SetNodeAttr(kAttrIRChange, MakeValue(ir_change), target_op);

  auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
  MS_LOG(DEBUG) << "Create op " << target_op->fullname_with_scope() << ", debug string:" << target_op->DebugString()
                << ", attr text:" << target_primitive->GetAttrsText() << " from " << origin_op->fullname_with_scope()
                << ", debug string:" << origin_op->DebugString() << ", attr text:" << origin_primitive->GetAttrsText()
                << ", is dynamic shape:" << is_dynamic;
  return target_op;
}
190
ConvertInputToAttr(const CNodePtr & origin_op,size_t i,const std::shared_ptr<AnfNode> & input_node,const std::string & attr_data_type,const std::shared_ptr<Primitive> & target_primitive)191 bool OpAdaptationInfoRegister::ConvertInputToAttr(const CNodePtr &origin_op, size_t i,
192 const std::shared_ptr<AnfNode> &input_node,
193 const std::string &attr_data_type,
194 const std::shared_ptr<Primitive> &target_primitive) {
195 MS_EXCEPTION_IF_NULL(origin_op);
196 MS_EXCEPTION_IF_NULL(input_node);
197 MS_EXCEPTION_IF_NULL(target_primitive);
198 auto value_node = input_node->cast<ValueNodePtr>();
199 MS_EXCEPTION_IF_NULL(value_node);
200 MS_LOG(DEBUG) << "start erase input[" << i
201 << "] of cnode[" + origin_op->DebugString() + "], origin value:" << value_node->ToString()
202 << ", Type:" << value_node->type_name();
203
204 auto value = value_node->value();
205 MS_EXCEPTION_IF_NULL(value);
206 if (value->isa<tensor::Tensor>()) {
207 auto tensor = value->cast<tensor::TensorPtr>();
208 if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
209 MS_LOG(DEBUG) << "Const input data ptr is null from op " << origin_op->fullname_with_scope() << "'s input " << i;
210 return false;
211 }
212 value = CreateValueFromTensor(tensor);
213 value = UpdateValueByAttrDataType(value, attr_data_type);
214 MS_LOG(DEBUG) << "new attr value:" << value_node->ToString() << ", Type:" << value_node->type_name();
215 }
216
217 std::string attr_name = common::AnfAlgo::GetInputName(origin_op, i);
218 if (attr_name.empty()) {
219 MS_LOG(DEBUG) << "Attr name is empty.";
220 return false;
221 }
222
223 if (origin_op->HasAttr(attr_name)) {
224 auto origin_primitive = GetCNodePrimitive(origin_op);
225 MS_EXCEPTION_IF_NULL(origin_primitive);
226 MS_LOG(ERROR) << "Origin op already has this attr " << attr_name
227 << ". op attrs:" << origin_primitive->GetAttrsText() << ". DebugString:" << origin_op->DebugString();
228 return false;
229 }
230
231 target_primitive->set_attr(attr_name, value);
232 return true;
233 }
234
RenamePrimitiveName(const CNodePtr & origin_op,const string & me_op_name,const string & backend_op_name)235 void OpAdaptationInfoRegister::RenamePrimitiveName(const CNodePtr &origin_op, const string &me_op_name,
236 const string &backend_op_name) {
237 MS_EXCEPTION_IF_NULL(origin_op);
238 if (backend_op_name == me_op_name) {
239 return;
240 }
241 auto primitive = GetCNodePrimitive(origin_op);
242 MS_EXCEPTION_IF_NULL(primitive);
243 primitive->set_name(backend_op_name);
244 // reset full scope name
245 origin_op->set_fullname_with_scope("");
246 MS_LOG(INFO) << "Rename op type from " << me_op_name << " to " << backend_op_name << " for op "
247 << origin_op->fullname_with_scope();
248 if (me_op_name == kSparseGatherV2OpName) {
249 common::AnfAlgo::SetNodeAttr(kAttrIsSparse, MakeValue(true), origin_op);
250 }
251 common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), origin_op);
252 }
253
// Variadic registration helper: the `len` trailing int arguments are the input
// indices to be converted to attributes (registered without an attr data
// type).  Constructs the OpAdaptationInfo and registers it with the factory;
// this object keeps ownership of the info via op_adaptation_info_.
RegisterHelper::RegisterHelper(const string &me_op_name, const string &device_name, bool flag, int len, ...) {
  mindspore::HashSet<size_t> input_to_attr;
  input_to_attr.reserve(static_cast<size_t>(IntToUint(len)));
  va_list var_ptr;
  va_start(var_ptr, len);
  // Each variadic argument is an int input index; duplicates collapse in the set.
  for (int i = 0; i < len; ++i) {
    (void)input_to_attr.insert(static_cast<size_t>(IntToUint(va_arg(var_ptr, int))));
  }
  va_end(var_ptr);
  op_adaptation_info_ = std::make_shared<OpAdaptationInfo>(me_op_name, device_name, flag);
  MS_EXCEPTION_IF_NULL(op_adaptation_info_);
  for (auto &index : input_to_attr) {
    (void)op_adaptation_info_->set_input_attr_info(index);
  }
  // The registry stores a raw pointer; lifetime is tied to this helper's shared_ptr.
  opt::OpAdaptationInfoRegister::GetInstance().RegOpAdaptationInfo(op_adaptation_info_.get());
}
270
RegisterHelper(const OpAdaptationInfo & op_adaptation_info)271 RegisterHelper::RegisterHelper(const OpAdaptationInfo &op_adaptation_info) {
272 op_adaptation_info_ = std::make_shared<OpAdaptationInfo>(op_adaptation_info);
273 MS_EXCEPTION_IF_NULL(op_adaptation_info_);
274 opt::OpAdaptationInfoRegister::GetInstance().RegOpAdaptationInfo(op_adaptation_info_.get());
275 }
/**
 * Copy assignment.  Self-assignment safe.
 *
 * Fix: the previous implementation omitted is_ascend_mindir_ (the member set
 * by set_is_ascend_mindir), so assigning an info object silently dropped that
 * flag; it is now copied along with every other member.
 */
OpAdaptationInfo &OpAdaptationInfo::operator=(const OpAdaptationInfo &op_adaptation_info) {
  if (this == &op_adaptation_info) {
    return *this;
  }
  me_op_name_ = op_adaptation_info.me_op_name_;
  backend_op_name_ = op_adaptation_info.backend_op_name_;
  target_op_name_ = op_adaptation_info.target_op_name_;
  pre_check_func_ = op_adaptation_info.pre_check_func_;
  need_tbe_check_supported_ = op_adaptation_info.need_tbe_check_supported_;
  input_attr_map_ = op_adaptation_info.input_attr_map_;
  device_name_ = op_adaptation_info.device_name_;
  flag_ = op_adaptation_info.flag_;
  is_ascend_mindir_ = op_adaptation_info.is_ascend_mindir_;  // previously missing
  return *this;
}
290 } // namespace mindspore::opt
291