• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/backend/optimizer/op_adaptation_info_factory.h"
17 
18 #include "kernel/oplib/oplib.h"
19 #include "utils/log_adapter.h"
20 #include "include/common/utils/anfalgo.h"
21 #include "include/common/utils/convert_utils.h"
22 #include "include/backend/optimizer/helper.h"
23 #include "ops/framework_ops.h"
24 
25 namespace mindspore::opt {
set_backend_op_name(const std::string & default_op_name)26 OpAdaptationInfo &OpAdaptationInfo::set_backend_op_name(const std::string &default_op_name) {
27   backend_op_name_ = default_op_name;
28   return *this;
29 }
30 
set_target_op_name(const std::string & target_op_name)31 OpAdaptationInfo &OpAdaptationInfo::set_target_op_name(const std::string &target_op_name) {
32   target_op_name_ = target_op_name;
33   return *this;
34 }
35 
set_pre_check_func(std::function<bool (CNodePtr)> pre_check_func)36 OpAdaptationInfo &OpAdaptationInfo::set_pre_check_func(std::function<bool(CNodePtr)> pre_check_func) {
37   pre_check_func_ = std::move(pre_check_func);
38   return *this;
39 }
40 
set_need_tbe_check_supported(bool need_tbe_check_supported)41 OpAdaptationInfo &OpAdaptationInfo::set_need_tbe_check_supported(bool need_tbe_check_supported) {
42   need_tbe_check_supported_ = need_tbe_check_supported;
43   return *this;
44 }
45 
set_input_attr_info(size_t input_index,const std::string & attr_data_type)46 OpAdaptationInfo &OpAdaptationInfo::set_input_attr_info(size_t input_index, const std::string &attr_data_type) {
47   auto find = input_attr_map_.find(input_index);
48   if (find != input_attr_map_.end()) {
49     MS_LOG(ERROR) << "This input index (" << input_index << ")"
50                   << " has been registered.";
51     return *this;
52   }
53   input_attr_map_[input_index] = attr_data_type;
54   return *this;
55 }
56 
set_is_ascend_mindir()57 OpAdaptationInfo &OpAdaptationInfo::set_is_ascend_mindir() {
58   is_ascend_mindir_ = true;
59   return *this;
60 }
61 
GetInstance()62 OpAdaptationInfoRegister &OpAdaptationInfoRegister::GetInstance() {
63   static OpAdaptationInfoRegister inst;
64   return inst;
65 }
66 
GenerateKey(const std::string & me_op_name,const std::string & device_name,bool flag)67 std::string OpAdaptationInfoRegister::GenerateKey(const std::string &me_op_name, const std::string &device_name,
68                                                   bool flag) {
69   if (device_name != kCPUDevice && device_name != kGPUDevice && device_name != kAscendDevice) {
70     MS_LOG(EXCEPTION) << "Backend type is invalid, should be one of [" << kCPUDevice << ", " << kGPUDevice << ", "
71                       << kAscendDevice << "], but got " << device_name;
72   }
73 
74   std::string flag_str = flag ? "true" : "false";
75   return std::string(me_op_name + device_name + flag_str);
76 }
77 
GetOpName()78 std::set<std::string> &OpAdaptationInfoRegister::GetOpName() {
79   static std::set<std::string> op_names;
80   return op_names;
81 }
82 
GetOpInfoMap()83 std::map<std::string, OpAdaptationInfo *> &OpAdaptationInfoRegister::GetOpInfoMap() {
84   static std::map<std::string, OpAdaptationInfo *> op_info_map;
85   return op_info_map;
86 }
87 
RegOpAdaptationInfo(OpAdaptationInfo * reg_info)88 void OpAdaptationInfoRegister::RegOpAdaptationInfo(OpAdaptationInfo *reg_info) {
89   MS_EXCEPTION_IF_NULL(reg_info);
90   (void)GetOpName().insert(reg_info->me_op_name());
91   auto key = GenerateKey(reg_info->me_op_name(), reg_info->device_name(), reg_info->flag());
92   auto find = GetOpInfoMap().find(key);
93   if (find != GetOpInfoMap().end()) {
94     MS_LOG(DEBUG) << "This key (" << key << ") has been registered in me op info map.";
95     return;
96   }
97   MS_LOG(DEBUG) << "Reg op adaptation info to factory, key: " << key;
98   GetOpInfoMap()[key] = reg_info;
99 }
100 
GetOpAdaptationInfo(const std::string & me_op_name,const std::string & device_name,bool flag)101 OpAdaptationInfo *OpAdaptationInfoRegister::GetOpAdaptationInfo(const std::string &me_op_name,
102                                                                 const std::string &device_name, bool flag) {
103   auto name_iter = GetOpName().find(me_op_name);
104   if (name_iter == GetOpName().end()) {
105     return nullptr;
106   }
107   auto key = GenerateKey(me_op_name, device_name, flag);
108   auto iter = GetOpInfoMap().find(key);
109   if (iter == GetOpInfoMap().end()) {
110     MS_LOG(DEBUG) << "Can't find op adaptation for op " << me_op_name << " on " << device_name << " when flag is "
111                   << flag;
112     return nullptr;
113   }
114   return iter->second;
115 }
116 
CreateTargetOp(const CNodePtr & origin_op,const OpAdaptationInfo & op_adaptation_info)117 CNodePtr OpAdaptationInfoRegister::CreateTargetOp(const CNodePtr &origin_op,
118                                                   const OpAdaptationInfo &op_adaptation_info) {
119   MS_EXCEPTION_IF_NULL(origin_op);
120   auto target_op_name = op_adaptation_info.target_op_name();
121   auto input_attr_info_map = op_adaptation_info.input_attr_map();
122 
123   auto origin_primitive = GetCNodePrimitive(origin_op);
124   MS_EXCEPTION_IF_NULL(origin_primitive);
125   auto target_primitive = std::make_shared<Primitive>(target_op_name);
126   MS_EXCEPTION_IF_NULL(target_primitive);
127   (void)target_primitive->SetAttrs(origin_primitive->attrs());
128   std::vector<AnfNodePtr> target_inputs;
129   auto inputs = origin_op->inputs();
130   target_inputs.push_back(inputs[0]);
131   auto graph = origin_op->func_graph();
132   bool ir_change = false;
133   for (size_t i = 0; i < inputs.size() - 1; ++i) {
134     auto input_node = inputs[i + 1];
135     MS_EXCEPTION_IF_NULL(input_node);
136     if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
137       input_node = AnfUtils::VisitKernel(input_node, 0).first;
138     }
139 
140     auto iter = input_attr_info_map.find(i);
141     if (iter != input_attr_info_map.end()) {
142       auto is_value_node = input_node->isa<ValueNode>();
143       auto is_monad = HasAbstractMonad(input_node);
144       if (!is_value_node || is_monad) {
145         MS_LOG(INFO) << "Convert " << origin_op->fullname_with_scope() << "'s input " << i
146                      << " to attr failed. input is value node: " << is_value_node << ", is monad: " << is_monad;
147         return nullptr;
148       }
149 
150       auto ret = ConvertInputToAttr(origin_op, i, input_node, iter->second, target_primitive);
151       if (!ret) {
152         MS_LOG(INFO) << "Convert " << origin_op->fullname_with_scope() << "'s input " << i << " to attr failed.";
153         return nullptr;
154       }
155       auto kernel_graph = graph->cast<KernelGraphPtr>();
156       MS_EXCEPTION_IF_NULL(kernel_graph);
157       (void)kernel_graph->RemoveValueNodeFromGraph(input_node->cast<ValueNodePtr>());
158       ir_change = true;
159     } else {
160       target_inputs.push_back(inputs[i + 1]);
161     }
162   }
163 
164   // Update target_op's inputs
165   target_inputs[0] = NewValueNode(target_primitive);
166   MS_EXCEPTION_IF_NULL(graph);
167   auto target_op = opt::NewCNode(target_inputs, graph, {origin_op});
168   MS_EXCEPTION_IF_NULL(target_op);
169   target_op->set_abstract(origin_op->abstract());
170   target_op->set_scope(origin_op->scope());
171   target_op->set_primal_attrs(origin_op->primal_attrs());
172   target_op->set_attrs(origin_op->attrs());
173   target_op->set_primal_debug_infos(origin_op->primal_debug_infos());
174   common::AnfAlgo::EraseNodeAttr(kAttrIsKernelDynamicImpl, target_op);
175   if (common::AnfAlgo::HasNodeAttr(kAttrCustAicpu, origin_op)) {
176     common::AnfAlgo::CopyNodeAttr(kAttrCustAicpu, origin_op, target_op);
177   }
178 
179   common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), target_op);
180   common::AnfAlgo::SetNodeAttr(kAttrMeOpName, MakeValue(op_adaptation_info.me_op_name()), target_op);
181   common::AnfAlgo::SetNodeAttr(kAttrIRChange, MakeValue(ir_change), target_op);
182 
183   auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
184   MS_LOG(DEBUG) << "Create op " << target_op->fullname_with_scope() << ", debug string:" << target_op->DebugString()
185                 << ", attr text:" << target_primitive->GetAttrsText() << " from " << origin_op->fullname_with_scope()
186                 << ", debug string:" << origin_op->DebugString() << ", attr text:" << origin_primitive->GetAttrsText()
187                 << ", is dynamic shape:" << is_dynamic;
188   return target_op;
189 }
190 
ConvertInputToAttr(const CNodePtr & origin_op,size_t i,const std::shared_ptr<AnfNode> & input_node,const std::string & attr_data_type,const std::shared_ptr<Primitive> & target_primitive)191 bool OpAdaptationInfoRegister::ConvertInputToAttr(const CNodePtr &origin_op, size_t i,
192                                                   const std::shared_ptr<AnfNode> &input_node,
193                                                   const std::string &attr_data_type,
194                                                   const std::shared_ptr<Primitive> &target_primitive) {
195   MS_EXCEPTION_IF_NULL(origin_op);
196   MS_EXCEPTION_IF_NULL(input_node);
197   MS_EXCEPTION_IF_NULL(target_primitive);
198   auto value_node = input_node->cast<ValueNodePtr>();
199   MS_EXCEPTION_IF_NULL(value_node);
200   MS_LOG(DEBUG) << "start erase input[" << i
201                 << "] of cnode[" + origin_op->DebugString() + "], origin value:" << value_node->ToString()
202                 << ", Type:" << value_node->type_name();
203 
204   auto value = value_node->value();
205   MS_EXCEPTION_IF_NULL(value);
206   if (value->isa<tensor::Tensor>()) {
207     auto tensor = value->cast<tensor::TensorPtr>();
208     if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
209       MS_LOG(DEBUG) << "Const input data ptr is null from op " << origin_op->fullname_with_scope() << "'s input " << i;
210       return false;
211     }
212     value = CreateValueFromTensor(tensor);
213     value = UpdateValueByAttrDataType(value, attr_data_type);
214     MS_LOG(DEBUG) << "new attr value:" << value_node->ToString() << ", Type:" << value_node->type_name();
215   }
216 
217   std::string attr_name = common::AnfAlgo::GetInputName(origin_op, i);
218   if (attr_name.empty()) {
219     MS_LOG(DEBUG) << "Attr name is empty.";
220     return false;
221   }
222 
223   if (origin_op->HasAttr(attr_name)) {
224     auto origin_primitive = GetCNodePrimitive(origin_op);
225     MS_EXCEPTION_IF_NULL(origin_primitive);
226     MS_LOG(ERROR) << "Origin op already has this attr " << attr_name
227                   << ". op attrs:" << origin_primitive->GetAttrsText() << ". DebugString:" << origin_op->DebugString();
228     return false;
229   }
230 
231   target_primitive->set_attr(attr_name, value);
232   return true;
233 }
234 
RenamePrimitiveName(const CNodePtr & origin_op,const string & me_op_name,const string & backend_op_name)235 void OpAdaptationInfoRegister::RenamePrimitiveName(const CNodePtr &origin_op, const string &me_op_name,
236                                                    const string &backend_op_name) {
237   MS_EXCEPTION_IF_NULL(origin_op);
238   if (backend_op_name == me_op_name) {
239     return;
240   }
241   auto primitive = GetCNodePrimitive(origin_op);
242   MS_EXCEPTION_IF_NULL(primitive);
243   primitive->set_name(backend_op_name);
244   // reset full scope name
245   origin_op->set_fullname_with_scope("");
246   MS_LOG(INFO) << "Rename op type from " << me_op_name << " to " << backend_op_name << " for op "
247                << origin_op->fullname_with_scope();
248   if (me_op_name == kSparseGatherV2OpName) {
249     common::AnfAlgo::SetNodeAttr(kAttrIsSparse, MakeValue(true), origin_op);
250   }
251   common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), origin_op);
252 }
253 
RegisterHelper(const string & me_op_name,const string & device_name,bool flag,int len,...)254 RegisterHelper::RegisterHelper(const string &me_op_name, const string &device_name, bool flag, int len, ...) {
255   mindspore::HashSet<size_t> input_to_attr;
256   input_to_attr.reserve(static_cast<size_t>(IntToUint(len)));
257   va_list var_ptr;
258   va_start(var_ptr, len);
259   for (int i = 0; i < len; ++i) {
260     (void)input_to_attr.insert(static_cast<size_t>(IntToUint(va_arg(var_ptr, int))));
261   }
262   va_end(var_ptr);
263   op_adaptation_info_ = std::make_shared<OpAdaptationInfo>(me_op_name, device_name, flag);
264   MS_EXCEPTION_IF_NULL(op_adaptation_info_);
265   for (auto &index : input_to_attr) {
266     (void)op_adaptation_info_->set_input_attr_info(index);
267   }
268   opt::OpAdaptationInfoRegister::GetInstance().RegOpAdaptationInfo(op_adaptation_info_.get());
269 }
270 
RegisterHelper(const OpAdaptationInfo & op_adaptation_info)271 RegisterHelper::RegisterHelper(const OpAdaptationInfo &op_adaptation_info) {
272   op_adaptation_info_ = std::make_shared<OpAdaptationInfo>(op_adaptation_info);
273   MS_EXCEPTION_IF_NULL(op_adaptation_info_);
274   opt::OpAdaptationInfoRegister::GetInstance().RegOpAdaptationInfo(op_adaptation_info_.get());
275 }
operator =(const OpAdaptationInfo & op_adaptation_info)276 OpAdaptationInfo &OpAdaptationInfo::operator=(const OpAdaptationInfo &op_adaptation_info) {
277   if (this == &op_adaptation_info) {
278     return *this;
279   }
280   me_op_name_ = op_adaptation_info.me_op_name_;
281   backend_op_name_ = op_adaptation_info.backend_op_name_;
282   target_op_name_ = op_adaptation_info.target_op_name_;
283   pre_check_func_ = op_adaptation_info.pre_check_func_;
284   need_tbe_check_supported_ = op_adaptation_info.need_tbe_check_supported_;
285   input_attr_map_ = op_adaptation_info.input_attr_map_;
286   device_name_ = op_adaptation_info.device_name_;
287   flag_ = op_adaptation_info.flag_;
288   return *this;
289 }
290 }  // namespace mindspore::opt
291