1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cstdint>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base/base.h"
#include "backend/common/graph_kernel/convert_input_and_attr.h"
#include "backend/common/graph_kernel/core/graph_kernel_callback.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/backend/optimizer/helper.h"
#include "include/api/format.h"
#include "ops/auto_generate/gen_ops_primitive.h"
#include "ops/array_ops.h"
#include "ops/op_def.h"
#include "ops/op_utils.h"
#include "ops/sequence_ops.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
37
38 namespace mindspore::graphkernel {
39 namespace {
GetConvertInputAttrOps()40 const std::set<std::string> &GetConvertInputAttrOps() {
41 static const std::set<std::string> convert_input_attr_ops = {
42 prim::kPrimSoftmax->name(), prim::kPrimReduceSum->name(), prim::kPrimReduceMax->name(),
43 prim::kPrimReduceMin->name(), prim::kPrimReduceMean->name(), prim::kPrimOneHot->name(),
44 prim::kPrimMinimumGrad->name(), prim::kPrimMaximumGrad->name(), prim::kPrimGather->name(),
45 prim::kPrimCumSum->name(), prim::kPrimArgmin->name(), prim::kPrimArgmax->name(),
46 prim::kPrimBiasAdd->name(), prim::kPrimBiasAddGrad->name(), prim::kPrimLayerNorm->name(),
47 prim::kPrimLayerNormGrad->name(), prim::kPrimLogSoftmax->name(), prim::kPrimLogSoftmaxGrad->name(),
48 prim::kPrimStridedSlice->name(), prim::kPrimAdamWeightDecay->name(), prim::kPrimMatMul->name(),
49 prim::kPrimBatchMatMul->name(),
50 };
51 return convert_input_attr_ops;
52 }
53
GetConvertKernelObjOps()54 const std::map<std::string, std::vector<size_t>> &GetConvertKernelObjOps() {
55 static const std::map<std::string, std::vector<size_t>> convert_kernel_obj_ops = {
56 {prim::kPrimReshape->name(), {2}},
57 {prim::kPrimReduceSum->name(), {2}}, // axis is tuple(int)
58 {prim::kPrimReduceMax->name(), {2}}, // axis is tuple(int)
59 {prim::kPrimReduceMin->name(), {2}}, // axis is tuple(int)
60 {prim::kPrimReduceMean->name(), {2}}, // axis is tuple(int)
61 {prim::kPrimStridedSlice->name(), {2, 3, 4}}, // begin, end, strides
62 {prim::kPrimTile->name(), {2}},
63 {prim::kPrimTranspose->name(), {2}},
64 };
65 return convert_kernel_obj_ops;
66 }
67
// Convert an Int64Imm format-enum value into its string form ("NCHW"/"NHWC"/"NCDHW").
// Throws when the value is not an Int64Imm or is an unsupported format.
ValuePtr EnumToFormat(const ValuePtr &value) {
  if (!value->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << value->ToString() << " is not Int64Imm.";
  }
  switch (GetValue<int64_t>(value)) {
    case Format::NCHW:
      return MakeValue("NCHW");
    case Format::NHWC:
      return MakeValue("NHWC");
    case Format::NCDHW:
      return MakeValue("NCDHW");
    default:
      MS_LOG(EXCEPTION) << value->ToString() << " is unexpected.";
  }
}
83
// Convert a string format value ("NCHW"/"NHWC"/"NCDHW") into its int64 enum form.
// Throws when the string is not one of the supported formats.
ValuePtr FormatToEnum(const ValuePtr &value) {
  const auto format = GetValue<std::string>(value);
  static const std::map<std::string, int64_t> format_map = {
    {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NCDHW", Format::NCDHW}};
  auto iter = format_map.find(format);
  if (iter == format_map.end()) {
    MS_LOG(EXCEPTION) << value->ToString() << " value:" << format << " is unexpected.";
  }
  return MakeValue<int64_t>(iter->second);
}
96
// Convert an Int64Imm type-id value into the corresponding Type object.
// Throws when the value is not an Int64Imm.
ValuePtr EnumToDtype(const ValuePtr &value) {
  if (!value->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << value->ToString() << " is not Int64Imm.";
  }
  return TypeIdToType(static_cast<TypeId>(GetValue<int64_t>(value)));
}
104
// Convert a Type value into its int64 type-id enum form.
// Throws when the value is not a Type.
ValuePtr DtypeToEnum(const ValuePtr &value) {
  auto type = value->cast<TypePtr>();
  if (type == nullptr) {
    MS_LOG(EXCEPTION) << value->ToString() << " is not Type.";
  }
  return MakeValue<int64_t>(type->type_id());
}
112
113 using ArgHandlerFunc = std::function<ValuePtr(const ValuePtr &)>;
114
GetArgHandlerFunc(const std::string & arg_handler)115 ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) {
116 static const std::unordered_map<std::string, ArgHandlerFunc> arg_handler_funcs = {
117 {"str_to_enum", EnumToFormat},
118 {"dtype_to_type_id", EnumToDtype},
119 };
120 if (arg_handler_funcs.find(arg_handler) != arg_handler_funcs.end()) {
121 return arg_handler_funcs.at(arg_handler);
122 } else {
123 return nullptr;
124 }
125 }
126
GetOppArgHandlerFunc(const std::string & arg_handler)127 ArgHandlerFunc GetOppArgHandlerFunc(const std::string &arg_handler) {
128 static const std::unordered_map<std::string, ArgHandlerFunc> opp_arg_handler_funcs = {
129 {"str_to_enum", FormatToEnum},
130 {"dtype_to_type_id", DtypeToEnum},
131 };
132 if (opp_arg_handler_funcs.find(arg_handler) != opp_arg_handler_funcs.end()) {
133 return opp_arg_handler_funcs.at(arg_handler);
134 } else {
135 return nullptr;
136 }
137 }
138 } // namespace
139
// Move the constant input of `cnode` at `input_index` (0-based, not counting the primitive
// input) into the attr `arg_name` of `primitive`. `arg_handler`, when non-empty, names a
// converter (see GetArgHandlerFunc) applied to the value before it is stored.
// Throws when the index is out of range or the input has no constant value.
// Note: the input itself is not removed here; the caller shrinks the input list.
void ConvertFrontEndToGraphKernel::AddConstInputToAttr(const CNodePtr &cnode, const size_t input_index,
                                                       const std::string &arg_name, const std::string &arg_handler,
                                                       const PrimitivePtr &primitive) {
  if (input_index >= cnode->size() - 1) {
    MS_LOG(EXCEPTION) << "The index of args in op_def `" << input_index
                      << "` should less than the inputs size minus one `" << cnode->size() - 1 << "`.";
  }
  // inputs()[0] is the primitive, so the real input lives at input_index + 1.
  auto input_node = cnode->inputs()[input_index + 1];

  // Extract the constant value either directly from a ValueNode or from a
  // Parameter's abstract (covers constants lifted into parameters).
  ValuePtr value = nullptr;
  if (input_node->isa<ValueNode>()) {
    auto value_node = input_node->cast<ValueNodePtr>();
    value = value_node->value();
  } else if (input_node->isa<Parameter>()) {
    auto parameter_node = input_node->cast<ParameterPtr>();
    value = parameter_node->abstract()->BuildValue();
  }
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << cnode->ToString() << " is not Value.";
  }
  // ValueAny means the value is not statically known, so it cannot become an attr.
  if (value->isa<ValueAny>()) {
    MS_LOG(EXCEPTION) << cnode->ToString() << " is ValueAny.";
  }
  // Apply the declared converter (e.g. enum -> format string) unless the value is None.
  if (!arg_handler.empty() && !value->isa<None>()) {
    auto arg_handler_func = GetArgHandlerFunc(arg_handler);
    MS_EXCEPTION_IF_NULL(arg_handler_func);
    value = arg_handler_func(value);
    primitive->AddAttr(arg_name, value);
    return;
  }

  // Non-tensor values are stored as-is.
  if (!value->isa<tensor::Tensor>()) {
    primitive->AddAttr(arg_name, value);
    return;
  }
  // Tensor values become an int attr (0-D tensor) or a vector<int> attr otherwise.
  auto value_vector = CheckAndConvertUtils::CheckTensorIntValue(arg_name, value, primitive->name());
  auto tensor = value->cast<tensor::TensorPtr>();
  auto tensor_shape = tensor->shape_c();
  MS_LOG(DEBUG) << cnode->ToString() << " 's input[" << input_index << "] is tensor.";
  if (tensor_shape.empty()) {
    // Scalar tensor: store its single element as a scalar attr.
    primitive->AddAttr(arg_name, MakeValue(value_vector[0]));
  } else {
    primitive->AddAttr(arg_name, MakeValue(value_vector));
  }
}
185
// Convert the tail "init args" of `cnode` (OpDef args flagged with as_init_arg_) from
// inputs into attrs of `primitive`, then rebuild the node's input list without them.
// Returns true when the node was modified.
bool ConvertFrontEndToGraphKernel::Process(const CNodePtr &cnode, const ops::OpDefPtr &op_def,
                                           const PrimitivePtr &primitive) {
  const auto &op_def_args = op_def->args_;
  const auto &op_def_indexes = op_def->indexes_;
  bool changed = false;
  auto ori_input_size = AnfUtils::GetInputTensorNum(cnode);
  if (op_def_args.size() != ori_input_size) {
    MS_LOG(EXCEPTION) << "The size of args in op_def `" << op_def->args_.size()
                      << "` should be equal to the inputs size minus one `" << ori_input_size << "`.";
  }
  // Walk the arg list backwards: args needing conversion sit at the tail, so we can
  // stop at the first arg that is not an init arg. `new_input_size` ends up as the
  // number of real inputs to keep.
  auto iter = op_def_args.crbegin();
  auto new_input_size = op_def_args.size();
  for (; iter != op_def_args.crend(); ++iter, --new_input_size) {
    // as_init_arg_ == 1 indicate the arg need convert, the arg need convert is at the tail of the list
    if (iter->as_init_arg_ != 1) {
      break;
    }
    const auto &arg_name = iter->arg_name_;
    const auto &arg_handler = iter->arg_handler_;
    MS_LOG(DEBUG) << cnode->ToString() << " convert input to attr: " << arg_name;
    if (auto index_iter = op_def_indexes.find(arg_name); index_iter != op_def_indexes.end()) {
      AddConstInputToAttr(cnode, index_iter->second, arg_name, arg_handler, primitive);
      changed = true;
    } else {
      MS_LOG(EXCEPTION) << primitive->name() << " not found index of attr[" << arg_name << "] in op def indexes.";
    }
  }
  auto inputs = cnode->inputs();
  if (changed) {
    // remainder args in op_def_arg is the size of new input args
    // Keep the primitive input plus the first `new_input_size` real inputs.
    AnfNodePtrList new_inputs(inputs.begin(), inputs.begin() + new_input_size + 1);
    // Also keep any inputs beyond the tensor inputs (presumably monad inputs —
    // GetInputTensorNum does not count them; verify against callers).
    for (size_t i = ori_input_size; i < inputs.size() - 1; ++i) {
      new_inputs.emplace_back(inputs[i + 1]);
    }
    cnode->set_inputs(new_inputs);
    auto cb = Callback::Instance();
    MS_EXCEPTION_IF_NULL(cb);
    // The input list changed, so the kernel info must be rebuilt from scratch.
    cb->ResetKernelInfoInputs(cnode, {});
  }
  return changed;
}
227
Run(const FuncGraphPtr & func_graph)228 bool ConvertFrontEndToGraphKernel::Run(const FuncGraphPtr &func_graph) {
229 bool changed = false;
230 MS_EXCEPTION_IF_NULL(func_graph);
231 MS_EXCEPTION_IF_NULL(func_graph->get_return());
232 auto todos = TopoSort(func_graph->get_return());
233 for (auto &node : todos) {
234 if (!OpDefAdapter::NeedConvertInputAndAttr(node)) {
235 continue;
236 }
237 auto primitive = GetCNodePrimitive(node);
238 if (primitive == nullptr) {
239 continue;
240 }
241 const auto &op_name = primitive->name();
242 auto op_def = mindspore::ops::GetOpDef(op_name);
243 if (op_def == nullptr) {
244 MS_LOG(WARNING) << op_name << " not found in op def.";
245 continue;
246 }
247 auto cnode = dyn_cast<CNode>(node);
248 changed = Process(cnode, op_def, primitive) || changed;
249 }
250 if (changed) {
251 auto mng = GkUtils::GetFuncGraphManager(func_graph);
252 GkUtils::UpdateFuncGraphManager(mng, func_graph);
253 }
254 return changed;
255 }
256
AddAttrToInput(const CNodePtr & cnode,const std::string & arg_name,const std::string & arg_handler,const PrimitivePtr & primitive,size_t pos)257 void ConvertGraphKernelToFrontEnd::AddAttrToInput(const CNodePtr &cnode, const std::string &arg_name,
258 const std::string &arg_handler, const PrimitivePtr &primitive,
259 size_t pos) {
260 auto value = primitive->GetAttr(arg_name);
261 if (!arg_handler.empty()) {
262 auto opp_arg_handler_func = GetOppArgHandlerFunc(arg_handler);
263 MS_EXCEPTION_IF_NULL(opp_arg_handler_func);
264 value = opp_arg_handler_func(value);
265 }
266 auto value_node = opt::CreateValueNodeWithKernelInfo(cnode->func_graph(), value);
267 auto inputs = cnode->inputs();
268 inputs.insert(inputs.begin() + pos, value_node);
269 cnode->set_inputs(inputs);
270 primitive->DelAttr(arg_name);
271 }
272
ConvertInputsType(const CNodePtr & cnode,size_t idx,ops::OP_DTYPE fe_arg_type)273 bool ConvertGraphKernelToFrontEnd::ConvertInputsType(const CNodePtr &cnode, size_t idx, ops::OP_DTYPE fe_arg_type) {
274 // Only convert ValueNode(tensor with dtype int64_t) to ValueNode(Tuple of int64_t) now.
275 MS_EXCEPTION_IF_NULL(cnode);
276 auto input = cnode->input(idx);
277 MS_EXCEPTION_IF_NULL(input);
278 if (!input->isa<ValueNode>()) {
279 return false;
280 }
281
282 auto origin_type = AnfAlgo::GetAbstractObjectType(input->abstract());
283 if (origin_type != kObjectTypeTensorType || fe_arg_type != ops::DT_TUPLE_INT) {
284 return false;
285 }
286
287 auto value_opt = ops::GetArrayValue<int64_t>(input->cast<ValueNodePtr>()->value());
288 if (!value_opt.has_value()) {
289 return false;
290 }
291
292 auto value_vec = value_opt.value().ToVector();
293 auto func_graph = cnode->func_graph();
294 auto new_input = opt::CreateValueNodeWithKernelInfo(func_graph, MakeValue<std::vector<int64_t>>(value_vec));
295 MS_LOG(DEBUG) << "Change [" << idx << "] input from " << input->DebugString() << " to " << new_input->DebugString()
296 << " for " << cnode->fullname_with_scope();
297 cnode->set_input(idx, new_input);
298 return true;
299 }
300
// Convert `node` back to its front-end form: move attrs back into inputs and fix
// input object types (e.g. tensor -> tuple) so the node matches its OpDef.
// Returns true when the node was modified.
bool ConvertGraphKernelToFrontEnd::Process(const AnfNodePtr &node) {
  auto primitive = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &op_name = primitive->name();
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    MS_LOG(WARNING) << op_name << " not found in op def.";
    return false;
  }
  const auto &op_def_args = op_def->args_;

  // 1. Convert attr to input.
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto ori_input_size = AnfUtils::GetInputTensorNum(cnode);
  if (ori_input_size > op_def_args.size()) {
    // More inputs than OpDef args: nothing to restore, just log for diagnosis.
    MS_LOG(INFO) << node->fullname_with_scope() << " ori_input_size:" << ori_input_size << " > "
                 << "op_def_args.size():" << op_def_args.size();
  }

  // 1-based cnode input positions whose kernel info must be rebuilt afterwards.
  std::vector<size_t> update_indices;
  // Args from ori_input_size onwards are missing from the inputs: they were
  // converted to attrs earlier and must be re-inserted as inputs now.
  for (auto i = ori_input_size; i < op_def_args.size(); i++) {
    // as_init_arg_ == 1 indicate the arg need convert
    if (op_def_args[i].as_init_arg_ != 1) {
      MS_LOG(EXCEPTION) << primitive->name() << "'s input:" << op_def_args[i].arg_name_
                        << " must have as_init_arg_ when convert attr to input.";
    }
    MS_LOG(DEBUG) << cnode->DebugString() << " convert attr [" << op_def_args[i].arg_name_ << "] to input: " << i;
    // Insert at i + 1 because cnode input 0 is the primitive itself.
    ConvertGraphKernelToFrontEnd::AddAttrToInput(cnode, op_def_args[i].arg_name_, op_def_args[i].arg_handler_,
                                                 primitive, i + 1);
    (void)update_indices.emplace_back(i + 1);
  }

  // 2. Convert inputs type.
  auto obj_map_iter = GetConvertKernelObjOps().find(op_name);
  if (obj_map_iter != GetConvertKernelObjOps().end()) {
    auto indices = obj_map_iter->second;
    for (auto idx : indices) {
      // idx is a 1-based cnode input position; the matching op_def arg is idx - 1.
      if (ConvertGraphKernelToFrontEnd::ConvertInputsType(cnode, idx, op_def_args[idx - 1].arg_dtype_)) {
        (void)update_indices.emplace_back(idx);
      }
    }
  }
  bool changed = !update_indices.empty();
  if (changed) {
    auto cb = Callback::Instance();
    MS_EXCEPTION_IF_NULL(cb);
    // Refresh kernel info only for the inputs that changed.
    cb->ResetKernelInfoInputs(cnode, update_indices);
  }
  return changed;
}
352
Run(const FuncGraphPtr & func_graph)353 bool ConvertGraphKernelToFrontEnd::Run(const FuncGraphPtr &func_graph) {
354 bool changed = false;
355 MS_EXCEPTION_IF_NULL(func_graph);
356 MS_EXCEPTION_IF_NULL(func_graph->get_return());
357 auto todos = TopoSort(func_graph->get_return());
358 for (auto &node : todos) {
359 if (OpDefAdapter::NeedConvertGK2FE(node)) {
360 changed = ConvertGraphKernelToFrontEnd::Process(node) || changed;
361 }
362 }
363 if (changed) {
364 auto mng = GkUtils::GetFuncGraphManager(func_graph);
365 GkUtils::UpdateFuncGraphManager(mng, func_graph);
366 }
367 return changed;
368 }
369
NeedConvertInputAndAttr(const AnfNodePtr & node)370 bool OpDefAdapter::NeedConvertInputAndAttr(const AnfNodePtr &node) {
371 return node->isa<CNode>() && GetConvertInputAttrOps().count(AnfUtils::GetCNodeName(node)) != 0;
372 }
373
NeedConvertGK2FE(const AnfNodePtr & node)374 bool OpDefAdapter::NeedConvertGK2FE(const AnfNodePtr &node) {
375 auto cnode = node->cast<CNodePtr>();
376 if (cnode == nullptr) {
377 return false;
378 }
379 auto op_name = AnfUtils::GetCNodeName(node);
380 if (GetConvertInputAttrOps().count(op_name) > 0) {
381 return true;
382 }
383 auto obj_map_iter = GetConvertKernelObjOps().find(op_name);
384 if (obj_map_iter == GetConvertKernelObjOps().end()) {
385 return false;
386 }
387 auto &index = obj_map_iter->second;
388 // if the input type is tensor, it need to convert to the type (like tuple) that match OpDef.
389 for (auto idx : index) {
390 if (idx < cnode->size() && cnode->input(idx)->abstract()->GetShape()->isa<abstract::TensorShape>()) {
391 return true;
392 }
393 }
394 return false;
395 }
396 } // namespace mindspore::graphkernel
397