1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/proactive_fallback_expander.h"
18
19 #include <unordered_set>
20 #include <vector>
21 #include <string>
22 #include <memory>
23
24 #include "ir/anf.h"
25 #include "utils/ms_context.h"
26 #include "kernel/graph_kernel_info.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 #include "backend/common/expander/fallback/fallback_irbuilder.h"
29
30 namespace mindspore::graphkernel {
GetFallbackOps()31 const std::unordered_set<std::string> &ProactiveFallbackExpander::GetFallbackOps() {
32 static const std::unordered_set<std::string> fallback_ops_list_ = {"AddExt", "SubExt"};
33 return fallback_ops_list_;
34 }
35
Run(const FuncGraphPtr & func_graph)36 bool ProactiveFallbackExpander::Run(const FuncGraphPtr &func_graph) {
37 MS_EXCEPTION_IF_NULL(func_graph);
38 MS_EXCEPTION_IF_NULL(func_graph->get_return());
39 auto mng = func_graph->manager();
40 MS_EXCEPTION_IF_NULL(mng);
41 auto nodes = TopoSort(func_graph->get_return());
42 const auto &need_fallback_ops = GetFallbackOps();
43 for (const auto &node : nodes) {
44 if (!node->isa<CNode>()) {
45 continue;
46 }
47 auto cnode = node->cast<CNodePtr>();
48 const std::string &prim_name = GetCNodePrimitive(cnode)->name();
49 if (need_fallback_ops.find(prim_name) == need_fallback_ops.end()) {
50 continue;
51 }
52 MS_LOG(DEBUG) << "Start Fallback node: " << cnode->fullname_with_scope();
53 auto func = [](const CNodePtr &cnode) -> bool {
54 MS_EXCEPTION_IF_NULL(cnode);
55 for (size_t i = 1; i < cnode->size(); i++) {
56 const auto &input = cnode->input(i);
57 if (!input->isa<ValueNode>()) {
58 continue;
59 }
60 auto input_kernel_info = input->kernel_info_ptr();
61 if (input_kernel_info == nullptr) {
62 input_kernel_info = std::make_shared<device::KernelInfo>();
63 input->set_kernel_info(input_kernel_info);
64 }
65 if (input_kernel_info->has_build_info()) {
66 continue;
67 }
68 auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
69 MS_EXCEPTION_IF_NULL(info_builder);
70 auto vnode = input->cast<ValueNodePtr>();
71 auto value = vnode->value();
72 MS_EXCEPTION_IF_NULL(value);
73 if (value->isa<tensor::Tensor>()) {
74 auto tensor = value->cast<tensor::TensorPtr>();
75 info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
76 info_builder->SetOutputsDeviceType(std::vector<TypeId>{tensor->Dtype()->type_id()});
77 } else if (value->isa<Scalar>()) {
78 auto scalar = value->cast<ScalarPtr>();
79 info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
80 info_builder->SetOutputsDeviceType(std::vector<TypeId>{scalar->type()->type_id()});
81 } else {
82 return false;
83 }
84 AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), input.get());
85 }
86 auto kernel_info = cnode->kernel_info_ptr();
87 if (kernel_info == nullptr) {
88 cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
89 }
90 auto kernel_info_setter = GraphKernelInfoManager::Instance().GetGraphKernelInfo(kAscendDevice);
91 MS_EXCEPTION_IF_NULL(kernel_info_setter);
92 kernel_info_setter->SetKernelInfo(cnode, KernelType::UNKNOWN_KERNEL_TYPE);
93 return true;
94 };
95 expander::FallbackIRBuilder ib(prim_name, cnode->func_graph(), func);
96 const auto *handle = expander::IRBuilderFactory::Instance().GetBuilder(prim_name);
97 if (handle == nullptr) {
98 MS_LOG(EXCEPTION) << "No fallback handle for node: " << cnode->fullname_with_scope();
99 return false;
100 }
101 auto output = ib.Run(cnode, *handle);
102 (void)mng->Replace(cnode, output);
103 }
104 return true;
105 }
106
107 } // namespace mindspore::graphkernel
108