• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/proactive_fallback_expander.h"
18 
19 #include <unordered_set>
20 #include <vector>
21 #include <string>
22 #include <memory>
23 
24 #include "ir/anf.h"
25 #include "utils/ms_context.h"
26 #include "kernel/graph_kernel_info.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 #include "backend/common/expander/fallback/fallback_irbuilder.h"
29 
30 namespace mindspore::graphkernel {
GetFallbackOps()31 const std::unordered_set<std::string> &ProactiveFallbackExpander::GetFallbackOps() {
32   static const std::unordered_set<std::string> fallback_ops_list_ = {"AddExt", "SubExt"};
33   return fallback_ops_list_;
34 }
35 
Run(const FuncGraphPtr & func_graph)36 bool ProactiveFallbackExpander::Run(const FuncGraphPtr &func_graph) {
37   MS_EXCEPTION_IF_NULL(func_graph);
38   MS_EXCEPTION_IF_NULL(func_graph->get_return());
39   auto mng = func_graph->manager();
40   MS_EXCEPTION_IF_NULL(mng);
41   auto nodes = TopoSort(func_graph->get_return());
42   const auto &need_fallback_ops = GetFallbackOps();
43   for (const auto &node : nodes) {
44     if (!node->isa<CNode>()) {
45       continue;
46     }
47     auto cnode = node->cast<CNodePtr>();
48     const std::string &prim_name = GetCNodePrimitive(cnode)->name();
49     if (need_fallback_ops.find(prim_name) == need_fallback_ops.end()) {
50       continue;
51     }
52     MS_LOG(DEBUG) << "Start Fallback node: " << cnode->fullname_with_scope();
53     auto func = [](const CNodePtr &cnode) -> bool {
54       MS_EXCEPTION_IF_NULL(cnode);
55       for (size_t i = 1; i < cnode->size(); i++) {
56         const auto &input = cnode->input(i);
57         if (!input->isa<ValueNode>()) {
58           continue;
59         }
60         auto input_kernel_info = input->kernel_info_ptr();
61         if (input_kernel_info == nullptr) {
62           input_kernel_info = std::make_shared<device::KernelInfo>();
63           input->set_kernel_info(input_kernel_info);
64         }
65         if (input_kernel_info->has_build_info()) {
66           continue;
67         }
68         auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
69         MS_EXCEPTION_IF_NULL(info_builder);
70         auto vnode = input->cast<ValueNodePtr>();
71         auto value = vnode->value();
72         MS_EXCEPTION_IF_NULL(value);
73         if (value->isa<tensor::Tensor>()) {
74           auto tensor = value->cast<tensor::TensorPtr>();
75           info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
76           info_builder->SetOutputsDeviceType(std::vector<TypeId>{tensor->Dtype()->type_id()});
77         } else if (value->isa<Scalar>()) {
78           auto scalar = value->cast<ScalarPtr>();
79           info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
80           info_builder->SetOutputsDeviceType(std::vector<TypeId>{scalar->type()->type_id()});
81         } else {
82           return false;
83         }
84         AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), input.get());
85       }
86       auto kernel_info = cnode->kernel_info_ptr();
87       if (kernel_info == nullptr) {
88         cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
89       }
90       auto kernel_info_setter = GraphKernelInfoManager::Instance().GetGraphKernelInfo(kAscendDevice);
91       MS_EXCEPTION_IF_NULL(kernel_info_setter);
92       kernel_info_setter->SetKernelInfo(cnode, KernelType::UNKNOWN_KERNEL_TYPE);
93       return true;
94     };
95     expander::FallbackIRBuilder ib(prim_name, cnode->func_graph(), func);
96     const auto *handle = expander::IRBuilderFactory::Instance().GetBuilder(prim_name);
97     if (handle == nullptr) {
98       MS_LOG(EXCEPTION) << "No fallback handle for node: " << cnode->fullname_with_scope();
99       return false;
100     }
101     auto output = ib.Run(cnode, *handle);
102     (void)mng->Replace(cnode, output);
103   }
104   return true;
105 }
106 
107 }  // namespace mindspore::graphkernel
108