• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <cstdint>
18 #include <functional>
19 #include <map>
20 #include <utility>
21 #include <vector>
22 #include <unordered_map>
23 #include "base/base.h"
24 #include "backend/common/graph_kernel/convert_input_and_attr.h"
25 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
26 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 #include "include/backend/optimizer/helper.h"
29 #include "include/api/format.h"
30 #include "ops/auto_generate/gen_ops_primitive.h"
31 #include "ops/array_ops.h"
32 #include "ops/op_def.h"
33 #include "ops/op_utils.h"
34 #include "ops/sequence_ops.h"
35 #include "utils/anf_utils.h"
36 #include "utils/check_convert_utils.h"
37 
38 namespace mindspore::graphkernel {
39 namespace {
GetConvertInputAttrOps()40 const std::set<std::string> &GetConvertInputAttrOps() {
41   static const std::set<std::string> convert_input_attr_ops = {
42     prim::kPrimSoftmax->name(),       prim::kPrimReduceSum->name(),       prim::kPrimReduceMax->name(),
43     prim::kPrimReduceMin->name(),     prim::kPrimReduceMean->name(),      prim::kPrimOneHot->name(),
44     prim::kPrimMinimumGrad->name(),   prim::kPrimMaximumGrad->name(),     prim::kPrimGather->name(),
45     prim::kPrimCumSum->name(),        prim::kPrimArgmin->name(),          prim::kPrimArgmax->name(),
46     prim::kPrimBiasAdd->name(),       prim::kPrimBiasAddGrad->name(),     prim::kPrimLayerNorm->name(),
47     prim::kPrimLayerNormGrad->name(), prim::kPrimLogSoftmax->name(),      prim::kPrimLogSoftmaxGrad->name(),
48     prim::kPrimStridedSlice->name(),  prim::kPrimAdamWeightDecay->name(), prim::kPrimMatMul->name(),
49     prim::kPrimBatchMatMul->name(),
50   };
51   return convert_input_attr_ops;
52 }
53 
GetConvertKernelObjOps()54 const std::map<std::string, std::vector<size_t>> &GetConvertKernelObjOps() {
55   static const std::map<std::string, std::vector<size_t>> convert_kernel_obj_ops = {
56     {prim::kPrimReshape->name(), {2}},
57     {prim::kPrimReduceSum->name(), {2}},           // axis is tuple(int)
58     {prim::kPrimReduceMax->name(), {2}},           // axis is tuple(int)
59     {prim::kPrimReduceMin->name(), {2}},           // axis is tuple(int)
60     {prim::kPrimReduceMean->name(), {2}},          // axis is tuple(int)
61     {prim::kPrimStridedSlice->name(), {2, 3, 4}},  // begin, end, strides
62     {prim::kPrimTile->name(), {2}},
63     {prim::kPrimTranspose->name(), {2}},
64   };
65   return convert_kernel_obj_ops;
66 }
67 
// Arg handler for "str_to_enum" (front end -> graph kernel): turns a Format enum value stored as Int64Imm into
// its string name.
// value: an Int64Imm holding Format::NCHW, Format::NHWC or Format::NCDHW.
// Returns a string value "NCHW", "NHWC" or "NCDHW".
// Throws (MS_LOG(EXCEPTION)) if value is not an Int64Imm or holds any other format.
ValuePtr EnumToFormat(const ValuePtr &value) {
  if (!value->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << value->ToString() << " is not Int64Imm.";
  }
  switch (GetValue<int64_t>(value)) {
    case Format::NCHW:
      return MakeValue("NCHW");
    case Format::NHWC:
      return MakeValue("NHWC");
    case Format::NCDHW:
      return MakeValue("NCDHW");
    default:
      MS_LOG(EXCEPTION) << value->ToString() << " is unexpected.";
  }
}
83 
// Inverse arg handler for "str_to_enum" (graph kernel -> front end). It turns a format name back into its
// Format enum value, the reverse of EnumToFormat.
// value: a StringImm holding "NCHW", "NHWC" or "NCDHW". It comes from a primitive attr, so it may be missing.
// Returns an Int64Imm holding the matching Format enum value.
// Throws (MS_LOG(EXCEPTION)) if value is null, is not a StringImm, or names an unsupported format.
ValuePtr FormatToEnum(const ValuePtr &value) {
  // Callers pass primitive->GetAttr() straight through, which can be null when the attr is absent.
  MS_EXCEPTION_IF_NULL(value);
  // Check the type up front, as the other handlers in this file do, rather than relying on GetValue's behaviour
  // for a non-string value.
  if (!value->isa<StringImm>()) {
    MS_LOG(EXCEPTION) << value->ToString() << " is not StringImm.";
  }
  auto format = GetValue<std::string>(value);
  if (format == "NCHW") {
    return MakeValue<int64_t>(Format::NCHW);
  } else if (format == "NHWC") {
    return MakeValue<int64_t>(Format::NHWC);
  } else if (format == "NCDHW") {
    return MakeValue<int64_t>(Format::NCDHW);
  } else {
    MS_LOG(EXCEPTION) << value->ToString() << " value:" << format << " is unexpected.";
  }
}
96 
// Arg handler for "dtype_to_type_id" (front end -> graph kernel): turns a TypeId stored as Int64Imm into the
// corresponding Type object via TypeIdToType.
// Throws (MS_LOG(EXCEPTION)) if value is not an Int64Imm.
ValuePtr EnumToDtype(const ValuePtr &value) {
  if (value->isa<Int64Imm>()) {
    const auto type_id = static_cast<TypeId>(GetValue<int64_t>(value));
    return TypeIdToType(type_id);
  }
  MS_LOG(EXCEPTION) << value->ToString() << " is not Int64Imm.";
}
104 
// Inverse arg handler for "dtype_to_type_id" (graph kernel -> front end). It turns a Type object back into its
// TypeId, stored as Int64Imm.
// value: the primitive attr value. It may be null when the attr is absent, since callers pass GetAttr() through.
// Throws (MS_LOG(EXCEPTION)) if value is null or is not a Type.
ValuePtr DtypeToEnum(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<Type>()) {
    MS_LOG(EXCEPTION) << value->ToString() << " is not Type.";
  }
  const auto type_id = value->cast<TypePtr>()->type_id();
  return MakeValue<int64_t>(static_cast<int64_t>(type_id));
}
112 
113 using ArgHandlerFunc = std::function<ValuePtr(const ValuePtr &)>;
114 
GetArgHandlerFunc(const std::string & arg_handler)115 ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) {
116   static const std::unordered_map<std::string, ArgHandlerFunc> arg_handler_funcs = {
117     {"str_to_enum", EnumToFormat},
118     {"dtype_to_type_id", EnumToDtype},
119   };
120   if (arg_handler_funcs.find(arg_handler) != arg_handler_funcs.end()) {
121     return arg_handler_funcs.at(arg_handler);
122   } else {
123     return nullptr;
124   }
125 }
126 
GetOppArgHandlerFunc(const std::string & arg_handler)127 ArgHandlerFunc GetOppArgHandlerFunc(const std::string &arg_handler) {
128   static const std::unordered_map<std::string, ArgHandlerFunc> opp_arg_handler_funcs = {
129     {"str_to_enum", FormatToEnum},
130     {"dtype_to_type_id", DtypeToEnum},
131   };
132   if (opp_arg_handler_funcs.find(arg_handler) != opp_arg_handler_funcs.end()) {
133     return opp_arg_handler_funcs.at(arg_handler);
134   } else {
135     return nullptr;
136   }
137 }
138 }  // namespace
139 
AddConstInputToAttr(const CNodePtr & cnode,const size_t input_index,const std::string & arg_name,const std::string & arg_handler,const PrimitivePtr & primitive)140 void ConvertFrontEndToGraphKernel::AddConstInputToAttr(const CNodePtr &cnode, const size_t input_index,
141                                                        const std::string &arg_name, const std::string &arg_handler,
142                                                        const PrimitivePtr &primitive) {
143   if (input_index >= cnode->size() - 1) {
144     MS_LOG(EXCEPTION) << "The index of args in op_def `" << input_index
145                       << "` should less than the inputs size minus one `" << cnode->size() - 1 << "`.";
146   }
147   auto input_node = cnode->inputs()[input_index + 1];
148 
149   ValuePtr value = nullptr;
150   if (input_node->isa<ValueNode>()) {
151     auto value_node = input_node->cast<ValueNodePtr>();
152     value = value_node->value();
153   } else if (input_node->isa<Parameter>()) {
154     auto parameter_node = input_node->cast<ParameterPtr>();
155     value = parameter_node->abstract()->BuildValue();
156   }
157   if (value == nullptr) {
158     MS_LOG(EXCEPTION) << cnode->ToString() << " is not Value.";
159   }
160   if (value->isa<ValueAny>()) {
161     MS_LOG(EXCEPTION) << cnode->ToString() << " is ValueAny.";
162   }
163   if (!arg_handler.empty() && !value->isa<None>()) {
164     auto arg_handler_func = GetArgHandlerFunc(arg_handler);
165     MS_EXCEPTION_IF_NULL(arg_handler_func);
166     value = arg_handler_func(value);
167     primitive->AddAttr(arg_name, value);
168     return;
169   }
170 
171   if (!value->isa<tensor::Tensor>()) {
172     primitive->AddAttr(arg_name, value);
173     return;
174   }
175   auto value_vector = CheckAndConvertUtils::CheckTensorIntValue(arg_name, value, primitive->name());
176   auto tensor = value->cast<tensor::TensorPtr>();
177   auto tensor_shape = tensor->shape_c();
178   MS_LOG(DEBUG) << cnode->ToString() << " 's input[" << input_index << "] is tensor.";
179   if (tensor_shape.empty()) {
180     primitive->AddAttr(arg_name, MakeValue(value_vector[0]));
181   } else {
182     primitive->AddAttr(arg_name, MakeValue(value_vector));
183   }
184 }
185 
Process(const CNodePtr & cnode,const ops::OpDefPtr & op_def,const PrimitivePtr & primitive)186 bool ConvertFrontEndToGraphKernel::Process(const CNodePtr &cnode, const ops::OpDefPtr &op_def,
187                                            const PrimitivePtr &primitive) {
188   const auto &op_def_args = op_def->args_;
189   const auto &op_def_indexes = op_def->indexes_;
190   bool changed = false;
191   auto ori_input_size = AnfUtils::GetInputTensorNum(cnode);
192   if (op_def_args.size() != ori_input_size) {
193     MS_LOG(EXCEPTION) << "The size of args in op_def `" << op_def->args_.size()
194                       << "` should be equal to the inputs size minus one `" << ori_input_size << "`.";
195   }
196   auto iter = op_def_args.crbegin();
197   auto new_input_size = op_def_args.size();
198   for (; iter != op_def_args.crend(); ++iter, --new_input_size) {
199     // as_init_arg_ == 1 indicate the arg need convert, the arg need convert is at the tail of the list
200     if (iter->as_init_arg_ != 1) {
201       break;
202     }
203     const auto &arg_name = iter->arg_name_;
204     const auto &arg_handler = iter->arg_handler_;
205     MS_LOG(DEBUG) << cnode->ToString() << " convert input to attr: " << arg_name;
206     if (auto index_iter = op_def_indexes.find(arg_name); index_iter != op_def_indexes.end()) {
207       AddConstInputToAttr(cnode, index_iter->second, arg_name, arg_handler, primitive);
208       changed = true;
209     } else {
210       MS_LOG(EXCEPTION) << primitive->name() << " not found index of attr[" << arg_name << "] in op def indexes.";
211     }
212   }
213   auto inputs = cnode->inputs();
214   if (changed) {
215     // remainder args in op_def_arg is the size of new input args
216     AnfNodePtrList new_inputs(inputs.begin(), inputs.begin() + new_input_size + 1);
217     for (size_t i = ori_input_size; i < inputs.size() - 1; ++i) {
218       new_inputs.emplace_back(inputs[i + 1]);
219     }
220     cnode->set_inputs(new_inputs);
221     auto cb = Callback::Instance();
222     MS_EXCEPTION_IF_NULL(cb);
223     cb->ResetKernelInfoInputs(cnode, {});
224   }
225   return changed;
226 }
227 
Run(const FuncGraphPtr & func_graph)228 bool ConvertFrontEndToGraphKernel::Run(const FuncGraphPtr &func_graph) {
229   bool changed = false;
230   MS_EXCEPTION_IF_NULL(func_graph);
231   MS_EXCEPTION_IF_NULL(func_graph->get_return());
232   auto todos = TopoSort(func_graph->get_return());
233   for (auto &node : todos) {
234     if (!OpDefAdapter::NeedConvertInputAndAttr(node)) {
235       continue;
236     }
237     auto primitive = GetCNodePrimitive(node);
238     if (primitive == nullptr) {
239       continue;
240     }
241     const auto &op_name = primitive->name();
242     auto op_def = mindspore::ops::GetOpDef(op_name);
243     if (op_def == nullptr) {
244       MS_LOG(WARNING) << op_name << " not found in op def.";
245       continue;
246     }
247     auto cnode = dyn_cast<CNode>(node);
248     changed = Process(cnode, op_def, primitive) || changed;
249   }
250   if (changed) {
251     auto mng = GkUtils::GetFuncGraphManager(func_graph);
252     GkUtils::UpdateFuncGraphManager(mng, func_graph);
253   }
254   return changed;
255 }
256 
AddAttrToInput(const CNodePtr & cnode,const std::string & arg_name,const std::string & arg_handler,const PrimitivePtr & primitive,size_t pos)257 void ConvertGraphKernelToFrontEnd::AddAttrToInput(const CNodePtr &cnode, const std::string &arg_name,
258                                                   const std::string &arg_handler, const PrimitivePtr &primitive,
259                                                   size_t pos) {
260   auto value = primitive->GetAttr(arg_name);
261   if (!arg_handler.empty()) {
262     auto opp_arg_handler_func = GetOppArgHandlerFunc(arg_handler);
263     MS_EXCEPTION_IF_NULL(opp_arg_handler_func);
264     value = opp_arg_handler_func(value);
265   }
266   auto value_node = opt::CreateValueNodeWithKernelInfo(cnode->func_graph(), value);
267   auto inputs = cnode->inputs();
268   inputs.insert(inputs.begin() + pos, value_node);
269   cnode->set_inputs(inputs);
270   primitive->DelAttr(arg_name);
271 }
272 
ConvertInputsType(const CNodePtr & cnode,size_t idx,ops::OP_DTYPE fe_arg_type)273 bool ConvertGraphKernelToFrontEnd::ConvertInputsType(const CNodePtr &cnode, size_t idx, ops::OP_DTYPE fe_arg_type) {
274   // Only convert ValueNode(tensor with dtype int64_t) to ValueNode(Tuple of int64_t) now.
275   MS_EXCEPTION_IF_NULL(cnode);
276   auto input = cnode->input(idx);
277   MS_EXCEPTION_IF_NULL(input);
278   if (!input->isa<ValueNode>()) {
279     return false;
280   }
281 
282   auto origin_type = AnfAlgo::GetAbstractObjectType(input->abstract());
283   if (origin_type != kObjectTypeTensorType || fe_arg_type != ops::DT_TUPLE_INT) {
284     return false;
285   }
286 
287   auto value_opt = ops::GetArrayValue<int64_t>(input->cast<ValueNodePtr>()->value());
288   if (!value_opt.has_value()) {
289     return false;
290   }
291 
292   auto value_vec = value_opt.value().ToVector();
293   auto func_graph = cnode->func_graph();
294   auto new_input = opt::CreateValueNodeWithKernelInfo(func_graph, MakeValue<std::vector<int64_t>>(value_vec));
295   MS_LOG(DEBUG) << "Change [" << idx << "] input from " << input->DebugString() << " to " << new_input->DebugString()
296                 << " for " << cnode->fullname_with_scope();
297   cnode->set_input(idx, new_input);
298   return true;
299 }
300 
Process(const AnfNodePtr & node)301 bool ConvertGraphKernelToFrontEnd::Process(const AnfNodePtr &node) {
302   auto primitive = GetCNodePrimitive(node);
303   MS_EXCEPTION_IF_NULL(primitive);
304   const auto &op_name = primitive->name();
305   auto op_def = mindspore::ops::GetOpDef(op_name);
306   if (op_def == nullptr) {
307     MS_LOG(WARNING) << op_name << " not found in op def.";
308     return false;
309   }
310   const auto &op_def_args = op_def->args_;
311 
312   // 1. Convert attr to input.
313   auto cnode = node->cast<CNodePtr>();
314   MS_EXCEPTION_IF_NULL(cnode);
315   auto ori_input_size = AnfUtils::GetInputTensorNum(cnode);
316   if (ori_input_size > op_def_args.size()) {
317     MS_LOG(INFO) << node->fullname_with_scope() << " ori_input_size:" << ori_input_size << " > "
318                  << "op_def_args.size():" << op_def_args.size();
319   }
320 
321   std::vector<size_t> update_indices;
322   for (auto i = ori_input_size; i < op_def_args.size(); i++) {
323     // as_init_arg_ == 1 indicate the arg need convert
324     if (op_def_args[i].as_init_arg_ != 1) {
325       MS_LOG(EXCEPTION) << primitive->name() << "'s input:" << op_def_args[i].arg_name_
326                         << " must have as_init_arg_ when convert attr to input.";
327     }
328     MS_LOG(DEBUG) << cnode->DebugString() << " convert attr [" << op_def_args[i].arg_name_ << "] to input: " << i;
329     ConvertGraphKernelToFrontEnd::AddAttrToInput(cnode, op_def_args[i].arg_name_, op_def_args[i].arg_handler_,
330                                                  primitive, i + 1);
331     (void)update_indices.emplace_back(i + 1);
332   }
333 
334   // 2. Convert inputs type.
335   auto obj_map_iter = GetConvertKernelObjOps().find(op_name);
336   if (obj_map_iter != GetConvertKernelObjOps().end()) {
337     auto indices = obj_map_iter->second;
338     for (auto idx : indices) {
339       if (ConvertGraphKernelToFrontEnd::ConvertInputsType(cnode, idx, op_def_args[idx - 1].arg_dtype_)) {
340         (void)update_indices.emplace_back(idx);
341       }
342     }
343   }
344   bool changed = !update_indices.empty();
345   if (changed) {
346     auto cb = Callback::Instance();
347     MS_EXCEPTION_IF_NULL(cb);
348     cb->ResetKernelInfoInputs(cnode, update_indices);
349   }
350   return changed;
351 }
352 
Run(const FuncGraphPtr & func_graph)353 bool ConvertGraphKernelToFrontEnd::Run(const FuncGraphPtr &func_graph) {
354   bool changed = false;
355   MS_EXCEPTION_IF_NULL(func_graph);
356   MS_EXCEPTION_IF_NULL(func_graph->get_return());
357   auto todos = TopoSort(func_graph->get_return());
358   for (auto &node : todos) {
359     if (OpDefAdapter::NeedConvertGK2FE(node)) {
360       changed = ConvertGraphKernelToFrontEnd::Process(node) || changed;
361     }
362   }
363   if (changed) {
364     auto mng = GkUtils::GetFuncGraphManager(func_graph);
365     GkUtils::UpdateFuncGraphManager(mng, func_graph);
366   }
367   return changed;
368 }
369 
NeedConvertInputAndAttr(const AnfNodePtr & node)370 bool OpDefAdapter::NeedConvertInputAndAttr(const AnfNodePtr &node) {
371   return node->isa<CNode>() && GetConvertInputAttrOps().count(AnfUtils::GetCNodeName(node)) != 0;
372 }
373 
NeedConvertGK2FE(const AnfNodePtr & node)374 bool OpDefAdapter::NeedConvertGK2FE(const AnfNodePtr &node) {
375   auto cnode = node->cast<CNodePtr>();
376   if (cnode == nullptr) {
377     return false;
378   }
379   auto op_name = AnfUtils::GetCNodeName(node);
380   if (GetConvertInputAttrOps().count(op_name) > 0) {
381     return true;
382   }
383   auto obj_map_iter = GetConvertKernelObjOps().find(op_name);
384   if (obj_map_iter == GetConvertKernelObjOps().end()) {
385     return false;
386   }
387   auto &index = obj_map_iter->second;
388   // if the input type is tensor, it need to convert to the type (like tuple) that match OpDef.
389   for (auto idx : index) {
390     if (idx < cnode->size() && cnode->input(idx)->abstract()->GetShape()->isa<abstract::TensorShape>()) {
391       return true;
392     }
393   }
394   return false;
395 }
396 }  // namespace mindspore::graphkernel
397