1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/adapter/callback_impl.h"
17
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ops/sequence_ops.h"
#include "utils/ms_context.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "kernel/framework_utils.h"
#include "backend/common/graph_kernel/adapter/fake_abstract_shape.h"
#include "backend/common/graph_kernel/convert_input_and_attr.h"
#include "kernel/graph_kernel_info.h"
#include "backend/common/pass/insert_type_transform_op.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
33
34 namespace mindspore::graphkernel {
35 namespace {
36 constexpr auto kPatternOpaque = "Opaque";
37
GetTypeIdForValueSequence(const ValueSequencePtr & value_sequence)38 TypeId GetTypeIdForValueSequence(const ValueSequencePtr &value_sequence) {
39 MS_EXCEPTION_IF_NULL(value_sequence);
40 const auto &element_values = value_sequence->value();
41 if (element_values.empty()) {
42 return kNumberTypeInt64;
43 }
44 const auto &first_element = element_values[0];
45 if (!first_element->isa<Scalar>()) {
46 MS_LOG(EXCEPTION) << "The value of " << value_sequence->ToString() << " is not a scalar.";
47 }
48 auto data_type = first_element->type();
49 MS_EXCEPTION_IF_NULL(data_type);
50 return data_type->type_id();
51 }
52
GetTypeAndFormats(const device::KernelWithIndex & kernel_with_index,std::vector<TypeId> * input_types,std::vector<std::string> * input_formats)53 void GetTypeAndFormats(const device::KernelWithIndex &kernel_with_index, std::vector<TypeId> *input_types,
54 std::vector<std::string> *input_formats) {
55 auto value_node = kernel_with_index.first->cast<ValueNodePtr>();
56 MS_EXCEPTION_IF_NULL(value_node);
57 auto value = value_node->value();
58 MS_EXCEPTION_IF_NULL(value);
59 if (value->isa<tensor::Tensor>()) {
60 auto tensor = value->cast<tensor::TensorPtr>();
61 MS_EXCEPTION_IF_NULL(tensor);
62 (void)input_types->emplace_back(tensor->data_type());
63 } else if (value->isa<ValueSequence>()) {
64 (void)input_types->emplace_back(GetTypeIdForValueSequence(value->cast<ValueSequencePtr>()));
65 } else if (value->isa<Scalar>()) {
66 auto scalar = value->cast<ScalarPtr>();
67 MS_EXCEPTION_IF_NULL(scalar);
68 auto data_type = scalar->type();
69 MS_EXCEPTION_IF_NULL(data_type);
70 (void)input_types->emplace_back(data_type->type_id());
71 } else {
72 MS_LOG(EXCEPTION) << "value " << value_node->ToString() << " is unexpected Type.";
73 }
74 (void)input_formats->emplace_back(kOpFormat_DEFAULT);
75 }
76 } // namespace
77
78 GRAPH_KERNEL_CALLBACK_REGISTER(CallbackImpl);
GetInputShape(const AnfNodePtr & node,size_t i)79 ShapeVector CallbackImpl::GetInputShape(const AnfNodePtr &node, size_t i) {
80 return AnfAlgo::GetInputDeviceShape(node, i);
81 }
82
GetOutputShape(const AnfNodePtr & node,size_t i)83 ShapeVector CallbackImpl::GetOutputShape(const AnfNodePtr &node, size_t i) {
84 return AnfAlgo::GetOutputDeviceShape(node, i);
85 }
86
GetInputInferShape(const AnfNodePtr & node,size_t i)87 ShapeVector CallbackImpl::GetInputInferShape(const AnfNodePtr &node, size_t i) {
88 return common::AnfAlgo::GetPrevNodeOutputInferShape(node, i);
89 }
90
GetOutputInferShape(const AnfNodePtr & node,size_t i)91 ShapeVector CallbackImpl::GetOutputInferShape(const AnfNodePtr &node, size_t i) {
92 return common::AnfAlgo::GetOutputInferShape(node, i);
93 }
94
GetInputType(const AnfNodePtr & node,size_t i)95 TypeId CallbackImpl::GetInputType(const AnfNodePtr &node, size_t i) { return AnfAlgo::GetInputDeviceDataType(node, i); }
96
GetOutputType(const AnfNodePtr & node,size_t i)97 TypeId CallbackImpl::GetOutputType(const AnfNodePtr &node, size_t i) {
98 return AnfAlgo::GetOutputDeviceDataType(node, i);
99 }
100
GetInputInferType(const AnfNodePtr & node,size_t i)101 TypeId CallbackImpl::GetInputInferType(const AnfNodePtr &node, size_t i) {
102 return common::AnfAlgo::GetPrevNodeOutputInferDataType(node, i);
103 }
104
GetOutputInferType(const AnfNodePtr & node,size_t i)105 TypeId CallbackImpl::GetOutputInferType(const AnfNodePtr &node, size_t i) {
106 return common::AnfAlgo::GetOutputInferDataType(node, i);
107 }
108
GetInputFormat(const AnfNodePtr & node,size_t i)109 std::string CallbackImpl::GetInputFormat(const AnfNodePtr &node, size_t i) { return AnfAlgo::GetInputFormat(node, i); }
110
GetOutputFormat(const AnfNodePtr & node,size_t i)111 std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
112 return AnfAlgo::GetOutputFormat(node, i);
113 }
114
GetProcessor(const AnfNodePtr & node)115 std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) {
116 auto processor = kernel::GetProcessorStr(node);
117 if (processor == kernel::kProcessorUnknown) {
118 // the processor will not be set during the Ascend kernel select, so it should be updated from context
119 processor = kernel::GetStrProcessorFromContext();
120 }
121 return processor;
122 }
123
GetTargetFromContextImpl(bool detail)124 std::string CallbackImpl::GetTargetFromContextImpl(bool detail) {
125 auto context_ptr = MsContext::GetInstance();
126 MS_EXCEPTION_IF_NULL(context_ptr);
127 const auto &target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
128 if (detail && target == kAscendDevice) {
129 return context_ptr->ascend_soc_name();
130 }
131 return target;
132 }
133
CollectInputTypesAndFormats(const AnfNodePtr & node,std::vector<TypeId> * input_types,std::vector<std::string> * input_formats,bool is_basic_node)134 void CallbackImpl::CollectInputTypesAndFormats(const AnfNodePtr &node, std::vector<TypeId> *input_types,
135 std::vector<std::string> *input_formats, bool is_basic_node) {
136 auto kernel_with_index = AnfUtils::VisitKernel(node, 0);
137 if (kernel_with_index.first->isa<ValueNode>()) {
138 GetTypeAndFormats(kernel_with_index, input_types, input_formats);
139 } else if (kernel_with_index.first->isa<Parameter>() && is_basic_node == false) {
140 (void)input_formats->emplace_back(kOpFormat_DEFAULT);
141 auto input_type = GetOutputInferType(kernel_with_index.first, kernel_with_index.second);
142 (void)input_types->emplace_back(input_type);
143 } else {
144 auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
145 (void)input_formats->emplace_back(std::move(input_format));
146 auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
147 (void)input_types->emplace_back(input_type);
148 }
149 }
150
// Select and attach kernel build info for a graph-kernel (fused subgraph) node.
// Input types/formats are collected from the producers of the outer cnode's
// inputs and mirrored onto the inner funcgraph's parameters; output
// types/formats are read from the real output nodes of the inner funcgraph.
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
  std::vector<std::string> graph_input_format;
  std::vector<TypeId> graph_input_type;
  std::vector<std::string> graph_output_format;
  std::vector<TypeId> graph_output_type;
  std::vector<kernel::KernelObjectType> graph_input_obj_type;
  std::vector<kernel::KernelObjectType> graph_output_obj_type;
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto fg = GetCNodeFuncGraph(node);
  MS_EXCEPTION_IF_NULL(fg);
  auto &inputs = cnode->inputs();
  // inputs[0] is the funcgraph; real inputs start at index 1, and input i maps
  // to inner parameter i - 1 (assumes parameters() matches the input count —
  // NOTE(review): not checked here).
  for (size_t i = 1; i < inputs.size(); ++i) {
    CollectInputTypesAndFormats(inputs[i], &graph_input_type, &graph_input_format);
    fg->parameters()[i - 1]->set_kernel_info(std::make_shared<device::KernelInfo>());
    kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
    // .back() is the entry just appended by CollectInputTypesAndFormats above.
    para_info_builder.SetOutputsFormat({graph_input_format.back()});
    para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
    para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i - 1].get());
  }
  // Expand a MakeTuple output into the individual output nodes.
  AnfNodePtrList outputs;
  if (IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
    auto fg_output = fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(fg_output);
    outputs.assign(fg_output->inputs().begin() + 1, fg_output->inputs().end());
  } else {
    outputs.push_back(fg->output());
  }
  // Read each output's selected device format/type from its real producer.
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto kernel_with_index = common::AnfAlgo::VisitKernel(outputs[i], 0);
    graph_output_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
    graph_output_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
  }
  opt::GenerateKernelObjectTypeForNewCNode(cnode, &graph_input_obj_type, &graph_output_obj_type);
  // Assemble and attach the build info for the fused node itself.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kPatternOpaque);
  graph_info_builder.SetInputsFormat(graph_input_format);
  graph_info_builder.SetInputsDeviceType(graph_input_type);
  graph_info_builder.SetOutputsFormat(graph_output_format);
  graph_info_builder.SetOutputsDeviceType(graph_output_type);
  graph_info_builder.SetInputsKernelObjectType(graph_input_obj_type);
  graph_info_builder.SetOutputsKernelObjectType(graph_output_obj_type);
  auto graph_selected_info = graph_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, node.get());
}
200
// Select and attach kernel build info for a basic (non-fused) node.
// `outputs_info` supplies device type/format/shape per output. When an output
// format is non-default, a fake abstract shape is derived from it (via
// GetFakeAbstractShape) and the node's abstract is overwritten accordingly.
void CallbackImpl::SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector<inner::NodeBase> &outputs_info) {
  node->set_kernel_info(std::make_shared<device::KernelInfo>());
  std::vector<std::string> input_formats;
  std::vector<TypeId> input_types;
  auto cnode = node->cast<CNodePtr>();
  if (cnode != nullptr) {
    // Collect type/format for each real input (inputs[0] is the primitive).
    auto &inputs = cnode->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      CollectInputTypesAndFormats(inputs[i], &input_types, &input_formats, true);
    }
  }

  std::vector<std::string> output_formats;
  std::vector<TypeId> output_types;
  AbstractBasePtrList abs_list;
  bool has_fake_abstract = false;
  for (size_t i = 0; i < outputs_info.size(); ++i) {
    output_formats.push_back(outputs_info[i].format);
    output_types.push_back(outputs_info[i].type);
    ShapeVector abs_shape;
    if (outputs_info[i].format != kOpFormat_DEFAULT) {
      // Non-default format: use the fake (format-adjusted) shape as abstract shape.
      abs_shape = GetFakeAbstractShape(outputs_info[i].shape, outputs_info[i].format);
      has_fake_abstract = true;
    } else {
      abs_shape = outputs_info[i].shape;
    }
    auto abs_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(outputs_info[i].type), abs_shape);
    abs_list.push_back(abs_tensor);
  }
  // Only overwrite the node's abstract when a fake shape was generated;
  // otherwise the existing abstract is left untouched.
  if (has_fake_abstract) {
    if (abs_list.size() == 1) {
      node->set_abstract(abs_list[0]);
    } else {
      node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
    }
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(input_formats);
  info_builder.SetInputsDeviceType(input_types);
  info_builder.SetOutputsFormat(output_formats);
  info_builder.SetOutputsDeviceType(output_types);
  info_builder.SetProcessor(kernel::GetProcessorFromContext());
  info_builder.SetKernelType(KernelType::AKG_KERNEL);
  info_builder.SetFusionType(kPatternOpaque);
  auto selected_info = info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(selected_info, node.get());
}
249
ResetKernelInfoInputs(const AnfNodePtr & node,const std::vector<size_t> & indices)250 void CallbackImpl::ResetKernelInfoInputs(const AnfNodePtr &node, const std::vector<size_t> &indices) {
251 MS_EXCEPTION_IF_NULL(node);
252 auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
253 if (kernel_info == nullptr) {
254 MS_LOG(DEBUG) << "KernelInfo do not exist for " << node->fullname_with_scope() << ", skip reset kernel info";
255 return;
256 }
257 auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
258 if (build_info == nullptr) {
259 MS_LOG(DEBUG) << "KernelBuildInfo do not exist for " << node->fullname_with_scope() << ", skip reset kernel info";
260 return;
261 }
262
263 std::vector<std::string> input_formats;
264 std::vector<TypeId> input_types;
265 std::vector<kernel::KernelObjectType> input_obj_type;
266 std::vector<kernel::KernelObjectType> output_obj_type;
267 auto cnode = node->cast<CNodePtr>();
268 if (cnode) {
269 auto orig_input_num = build_info->GetAllInputFormats().size();
270 auto &inputs = cnode->inputs();
271 std::vector<bool> visited(inputs.size(), false);
272 std::for_each(indices.begin(), indices.end(), [&visited](size_t index) { visited[index] = true; });
273 opt::GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
274 for (size_t i = 1; i < inputs.size(); ++i) {
275 if (visited[i]) {
276 CollectInputTypesAndFormats(inputs[i], &input_types, &input_formats, true);
277 } else {
278 auto input_idx = i - 1;
279 if (input_idx >= orig_input_num) {
280 MS_LOG(DEBUG) << "skip inputs[" << i << "] for node [" << node->fullname_with_scope() << "] "
281 << node->DebugString();
282 continue;
283 }
284 // reuse build info
285 input_types.emplace_back(build_info->GetInputDeviceType(input_idx));
286 input_formats.emplace_back(build_info->GetInputFormat(input_idx));
287 input_obj_type[input_idx] = build_info->GetInputKernelObjectType(input_idx);
288 }
289 }
290 }
291 auto input_num = AnfUtils::GetInputTensorNum(cnode);
292 if (input_formats.size() > input_num) {
293 input_formats.erase(input_formats.begin() + input_num, input_formats.end());
294 input_types.erase(input_types.begin() + input_num, input_types.end());
295 input_obj_type.erase(input_obj_type.begin() + input_num, input_obj_type.end());
296 }
297 build_info->SetInputsFormat(input_formats);
298 build_info->SetInputsDeviceType(input_types);
299 build_info->SetInputsKernelObjectType(input_obj_type);
300 }
301
SetEmptyKernelInfo(const AnfNodePtr & node)302 void CallbackImpl::SetEmptyKernelInfo(const AnfNodePtr &node) {
303 node->set_kernel_info(std::make_shared<device::KernelInfo>());
304 }
305
// Re-run backend kernel selection for `node` on the current device target.
// When the op needs attr->input conversion for selection, the selection runs on
// a temporary clone and the resulting kernel info is copied back to the
// original node afterwards. Reshape's non-default output formats are preserved
// across re-selection.
void CallbackImpl::ResetKernelInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto ori_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(ori_cnode);
  CNodePtr cnode = ori_cnode;
  bool need_convert = OpDefAdapter::NeedConvertGK2FE(cnode);
  if (need_convert) {
    // convert attr to input for selecting kernel, but not changed the original node.
    // the original cnode will be modified in the pass ConvertGraphKernelToFrontEnd of postprocess.
    cnode = node->func_graph()->NewCNode(ori_cnode->inputs());
    cnode->CloneCNodeInfo(ori_cnode);
    auto p = GetCNodePrimitive(ori_cnode);
    MS_EXCEPTION_IF_NULL(p);
    // Use a cloned primitive so Process() below cannot mutate the original's attrs.
    cnode->set_input(0, NewValueNode(p->Clone()));
    cnode->input(0)->set_abstract(ori_cnode->input(0)->abstract());
    cnode->input(0)->set_kernel_info(ori_cnode->input(0)->kernel_info_ptr());
    need_convert = ConvertGraphKernelToFrontEnd::Process(cnode);
    if (!need_convert) {
      // Nothing was actually converted; select on the original node directly.
      cnode = ori_cnode;
    }
  }
  // Remember Reshape's output formats when any is non-default, so they can be
  // restored after selection; an all-default list is cleared (nothing to restore).
  std::vector<std::string> ori_out_format;
  if (IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
    ori_out_format = AnfAlgo::GetAllOutputFormats(cnode);
    if (std::all_of(ori_out_format.begin(), ori_out_format.end(),
                    [](const std::string &f) { return f == kOpFormat_DEFAULT; })) {
      ori_out_format.clear();
    }
  }
  if (GetTargetFromContext() == kAscendDevice) {
    // Ascend: only create a KernelInfo when missing, keep an existing one.
    auto kernel_info = cnode->kernel_info_ptr();
    if (kernel_info == nullptr) {
      cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    }
    auto kernel_info_setter = GraphKernelInfoManager::Instance().GetGraphKernelInfo(kAscendDevice);
    MS_EXCEPTION_IF_NULL(kernel_info_setter);
    kernel_info_setter->SetKernelInfo(cnode, KernelType::UNKNOWN_KERNEL_TYPE);
  } else if (GetTargetFromContext() == kGPUDevice) {
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    auto kernel_info_setter = GraphKernelInfoManager::Instance().GetGraphKernelInfo(kGPUDevice);
    MS_EXCEPTION_IF_NULL(kernel_info_setter);
    kernel_info_setter->SetKernelInfo(cnode, KernelType::UNKNOWN_KERNEL_TYPE);
  } else {
    // CPU (default branch): the info setter may be absent, selection is best-effort.
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    auto kernel_info_setter = GraphKernelInfoManager::Instance().GetGraphKernelInfo(kCPUDevice);
    if (kernel_info_setter != nullptr) {
      kernel_info_setter->SetKernelInfo(cnode, KernelType::UNKNOWN_KERNEL_TYPE);
    }
  }
  if (!ori_out_format.empty()) {
    // Restore the remembered Reshape output formats onto the fresh build info.
    auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
    MS_EXCEPTION_IF_NULL(build_info);
    build_info->SetOutputsFormat(ori_out_format);
  }
  if (need_convert) {
    // Copy the selection result from the clone back to the original node, then
    // rebuild every input's info since the clone's inputs may differ.
    ori_cnode->set_kernel_info(cnode->kernel_info_ptr());
    std::vector<size_t> indices(ori_cnode->inputs().size());
    std::iota(indices.begin(), indices.end(), kIndex0);
    ResetKernelInfoInputs(ori_cnode, indices);
  }
}
369
GetInputShape(const AnfNodePtr & node,size_t i)370 ShapeVector CallbackImplWithInferShape::GetInputShape(const AnfNodePtr &node, size_t i) {
371 return CallbackImpl::GetInputInferShape(node, i);
372 }
373
GetOutputShape(const AnfNodePtr & node,size_t i)374 ShapeVector CallbackImplWithInferShape::GetOutputShape(const AnfNodePtr &node, size_t i) {
375 return common::AnfAlgo::GetOutputInferShape(node, i);
376 }
377
GetInputType(const AnfNodePtr & node,size_t i)378 TypeId CallbackImplWithInferShape::GetInputType(const AnfNodePtr &node, size_t i) {
379 return CallbackImpl::GetInputInferType(node, i);
380 }
381
GetOutputType(const AnfNodePtr & node,size_t i)382 TypeId CallbackImplWithInferShape::GetOutputType(const AnfNodePtr &node, size_t i) {
383 return CallbackImpl::GetOutputInferType(node, i);
384 }
385
GetInputFormat(const AnfNodePtr &,size_t)386 std::string CallbackImplWithInferShape::GetInputFormat(const AnfNodePtr &, size_t) { return kOpFormat_DEFAULT; }
387
GetOutputFormat(const AnfNodePtr &,size_t)388 std::string CallbackImplWithInferShape::GetOutputFormat(const AnfNodePtr &, size_t) { return kOpFormat_DEFAULT; }
389
// Infer-shape variant of SetBasicNodeKernelInfo. CNodes only receive an empty
// KernelInfo; for non-CNodes the output types/formats from `outputs_info` are
// written into a fresh build info, and a fake abstract shape replaces the
// node's abstract when an output format is non-default.
void CallbackImplWithInferShape::SetBasicNodeKernelInfo(const AnfNodePtr &node,
                                                        const std::vector<inner::NodeBase> &outputs_info) {
  node->set_kernel_info(std::make_shared<device::KernelInfo>());
  // CNodes are left with only the empty KernelInfo attached above.
  if (node->cast<CNodePtr>() != nullptr) {
    return;
  }
  bool has_fake_abstract = false;
  std::vector<TypeId> output_types;
  std::vector<std::string> output_formats;
  AbstractBasePtrList abs_list;
  for (size_t i = 0; i < outputs_info.size(); ++i) {
    output_types.push_back(outputs_info[i].type);
    output_formats.push_back(outputs_info[i].format);
    ShapeVector abs_shape;
    if (outputs_info[i].format != kOpFormat_DEFAULT) {
      // Non-default format: use the fake (format-adjusted) shape as abstract shape.
      abs_shape = GetFakeAbstractShape(outputs_info[i].shape, outputs_info[i].format);
      has_fake_abstract = true;
    } else {
      abs_shape = outputs_info[i].shape;
    }
    abs_list.push_back(std::make_shared<abstract::AbstractTensor>(TypeIdToType(outputs_info[i].type), abs_shape));
  }
  // Only overwrite the node's abstract when a fake shape was generated.
  if (has_fake_abstract) {
    if (abs_list.size() == 1) {
      node->set_abstract(abs_list[0]);
    } else {
      node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
    }
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetOutputsFormat(output_formats);
  info_builder.SetOutputsDeviceType(output_types);
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), node.get());
}
425
GetProcessor(const AnfNodePtr &)426 std::string CallbackImplWithInferShape::GetProcessor(const AnfNodePtr &) {
427 return kernel::GetStrProcessorFromContext();
428 }
429 } // namespace mindspore::graphkernel
430