1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/pi/graph_compiler/inliner/func_inliner.h"
18 #include <memory>
19 #include <string>
20 #include <algorithm>
21 #include "pipeline/jit/pi/graph_compiler/parser/byte_code_parser.h"
22 #include "pybind11/pytypes.h"
23 #include "utils/log_adapter.h"
24
25 namespace mindspore {
26 namespace pijit {
27 namespace {
GetPyFunction(const ir::NodePtr & node)28 py::object GetPyFunction(const ir::NodePtr &node) {
29 if (node->isa<ir::Value>()) {
30 auto value = node->cast<ir::ValuePtr>()->GetValue();
31 if (py::isinstance<py::function>(value)) {
32 return value;
33 }
34 } else {
35 MS_EXCEPTION_IF_CHECK_FAIL(node->isa<ir::LoadValueNode>(), "Expected load a function.");
36 auto load = node->cast<ir::LoadValueNodePtr>();
37 if (load->GetOpCode() == LOAD_GLOBAL) {
38 return GetPyFunction(load->GetArgs().back());
39 }
40 }
41 return py::none();
42 }
43 } // namespace
44
Run()45 void FuncInlineDetector::Run() { Visit(func_); }
46
Visit_(const ir::FunctionNodePtr & node)47 void FuncInlineDetector::Visit_(const ir::FunctionNodePtr &node) {
48 std::for_each(node->GetNodes().begin(), node->GetNodes().end(), [this](const ir::NodePtr &node) {
49 cur_root_node_ = node;
50 Visit(node);
51 index_++;
52 });
53 }
54
55 // Input form : {BUILD_TUPLE, (pos_args)}
56 // {BUILD_LIST, (pos_args)}
57 // {LIST_EXTEND, ([pos_args], (vargs))}
58 // {LIST_TO_TUPLE, [pos_args, vargs]}
UnpackArgShell(const ir::NodePtr & arg)59 const ir::NodePtrList &UnpackArgShell(const ir::NodePtr &arg) {
60 MS_EXCEPTION_IF_CHECK_FAIL(arg->isa<ir::Operation>(), "Arg should be a operation.");
61 const auto op = arg->cast<ir::OperationPtr>();
62 bool is_expected = op->GetOpCode() == BUILD_LIST || op->GetOpCode() == BUILD_TUPLE ||
63 op->GetOpCode() == LIST_TO_TUPLE || op->GetOpCode() == LIST_EXTEND;
64 MS_EXCEPTION_IF_CHECK_FAIL(is_expected, "Not expected operation.");
65 if (arg->isa<ir::CastNode>()) {
66 return UnpackArgShell(op->GetArg());
67 }
68 return op->GetArgs();
69 }
70
71 // Input form : {DICT_MERGE, ({}, kwargs)}
UnpackKwargsShell(const ir::NodePtr & kwargs)72 const ir::NodePtr &UnpackKwargsShell(const ir::NodePtr &kwargs) {
73 if (!kwargs->isa<ir::UpdateNode>()) {
74 return kwargs;
75 }
76 auto node = kwargs->cast<ir::UpdateNodePtr>();
77 bool is_valid = node->GetArg()->isa<ir::BuildNode>() && node->GetArg()->cast<ir::BuildNodePtr>()->GetArgsCnt() == 0;
78 MS_EXCEPTION_IF_CHECK_FAIL(is_valid, "First arg should be a empty build node.");
79 return node->GetArg(1);
80 }
81
82 // Input form : {}
83 // {(pos_args)}
84 // {varargs}
85 // {kwargs}
86 // {((pos_args), vargs)}
87 // {(pos_args), kwargs}
88 // {vargs, kwargs}
89 // {((pos_args), vargs), kwargs}
90 // Output form : {pos_args..., vargs, kwargs}
UnpackArgsInTuple(const py::object & func,ir::NodePtrList * args)91 void UnpackArgsInTuple(const py::object &func, ir::NodePtrList *args) {
92 MS_EXCEPTION_IF_CHECK_FAIL(py::isinstance<py::function>(func), "Should be a function object.");
93 const auto code = reinterpret_cast<const PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
94 if ((IntToSize(code->co_flags) & CO_VARKEYWORDS) != 0) {
95 args->back() = UnpackKwargsShell(args->back());
96 }
97 if (code->co_argcount == 0) {
98 return;
99 }
100 const auto &inner_args = UnpackArgShell(args->front());
101 args->erase(args->begin());
102 args->insert(args->begin(), inner_args.begin(), inner_args.end());
103 if ((IntToSize(code->co_flags) & CO_VARARGS) != 0) {
104 const auto &pos_args = UnpackArgShell(args->front());
105 args->erase(args->begin());
106 args->insert(args->begin(), pos_args.begin(), pos_args.end());
107 }
108 }
109
Visit_(const ir::CallNodePtr & node)110 void FuncInlineDetector::Visit_(const ir::CallNodePtr &node) {
111 auto arg = node->GetArg(0);
112 if (!CanBeInlined(arg)) {
113 std::for_each(node->GetArgs().begin(), node->GetArgs().end(), [this](const ir::NodePtr &node) { Visit(node); });
114 } else {
115 const py::object func = GetPyFunction(arg);
116 auto byteCodeParser = std::make_shared<ByteCodeParser>(func);
117 ir::FunctionNodePtr func_node = byteCodeParser->Parse();
118 ir::NodePtrList args(node->GetArgs().begin() + 1, node->GetArgs().end());
119 if (node->GetOpCode() != CALL_FUNCTION) {
120 UnpackArgsInTuple(func, &args);
121 }
122 EvolvingFunction(func_node, args);
123 node->SetArg(0, func_node);
124 node_2_index_[node] = index_;
125 node_2_root_[node] = cur_root_node_;
126 std::for_each(node->GetArgs().begin() + 1, node->GetArgs().end(), [this](const ir::NodePtr &node) { Visit(node); });
127 }
128 }
129
GetRootNodeIndex(const ir::CallNodePtr & node) const130 size_t FuncInlineDetector::GetRootNodeIndex(const ir::CallNodePtr &node) const {
131 MS_EXCEPTION_IF_CHECK_FAIL(node_2_index_.find(node) != node_2_index_.end(),
132 "Invalid Call Node %" + std::to_string(node->GetNodeId()) + ".");
133 return node_2_index_.at(node);
134 }
135
// Return the root node recorded for an inlinable call node during detection.
// Raises an exception when the call node was never recorded.
const ir::NodePtr &FuncInlineDetector::GetRootNode(const ir::CallNodePtr &node) const {
  const auto iter = node_2_root_.find(node);
  MS_EXCEPTION_IF_CHECK_FAIL(iter != node_2_root_.end(),
                             "Invalid Call Node %" + std::to_string(node->GetNodeId()) + ".");
  return iter->second;
}
141
CanBeInlined(const ir::NodePtr & node) const142 bool FuncInlineDetector::CanBeInlined(const ir::NodePtr &node) const {
143 if (!node->isa<ir::Value>() && !node->isa<ir::LoadValueNode>()) {
144 return false;
145 }
146 auto func = GetPyFunction(node);
147 return !py::isinstance<py::none>(func) && PyFunction_Check(func.ptr());
148 }
149
EvolvingFunction(const ir::FunctionNodePtr & func_node,const ir::NodePtrList args) const150 void FuncInlineDetector::EvolvingFunction(const ir::FunctionNodePtr &func_node, const ir::NodePtrList args) const {
151 // Rename the local variables of the function to avoid variable name conflicts after inlining
152 auto renamer = std::make_shared<FuncLocalVarRenamer>(func_node);
153 renamer->Run();
154 // Eliminate parameters
155 auto eliminator = std::make_shared<FuncParameterEliminator>(func_node, args);
156 eliminator->Run();
157 }
158
Run()159 void FuncLocalVarRenamer::Run() { Visit(func_); }
160
Visit_(const ir::ParameterPtr & node)161 void FuncLocalVarRenamer::Visit_(const ir::ParameterPtr &node) {
162 node->SetName(func_->GetName() + "_" + node->GetName());
163 }
164
Visit_(const ir::ValuePtr & node)165 void FuncLocalVarRenamer::Visit_(const ir::ValuePtr &node) {
166 if (node->GetScope() == ir::kScopeLocal) {
167 auto name = func_->GetName() + "_" + node->GetName();
168 node->SetValue(py::str(name));
169 node->SetName(name);
170 }
171 }
172
Run()173 void FuncParameterEliminator::Run() { Mutate(func_); }
174
// Bind a parameter name to the node that replaces it: the actual argument at the parameter's index
// when one was supplied, the parameter's default value otherwise. The parameter node itself is
// returned unchanged.
ir::NodePtr FuncParameterEliminator::Mutate_(const ir::ParameterPtr &node) {
  ir::NodePtr replacement;
  if (node->GetIndex() < args_.size()) {
    replacement = args_[node->GetIndex()];
  } else {
    replacement = node->GetDefaultValue();
  }
  name_2_node_[node->GetName()] = replacement;
  return node;
}
183
// Replace a load of a bound parameter name by the node bound to that name; any other load is kept.
// Raises an exception when the operand of the load is not a Value node.
ir::NodePtr FuncParameterEliminator::Mutate_(const ir::LoadValueNodePtr &node) {
  MS_EXCEPTION_IF_CHECK_FAIL(node->GetArg()->isa<ir::Value>(), "Expected a local var name.");
  const auto var_name = node->GetArg()->cast<ir::ValuePtr>()->GetName();
  const auto binding = name_2_node_.find(var_name);
  if (binding == name_2_node_.end()) {
    return node;
  }
  return binding->second;
}
192
// Mutate a store node. The left operand is always mutated.
// - STORE_FAST: the right operand must be a Value holding the variable name (an exception is raised
//   otherwise); that name is removed from the bindings, so later loads of it are no longer replaced.
// - Any other store: the right operand is mutated as well.
ir::NodePtr FuncParameterEliminator::Mutate_(const ir::StoreNodePtr &node) {
  node->SetLeftArg(Mutate(node->GetLeftArg()));
  if (node->GetOpCode() == STORE_FAST) {
    auto target = node->GetRightArg();
    MS_EXCEPTION_IF_CHECK_FAIL(target->isa<ir::Value>(), "Expected a local var name.");
    const auto stored_name = py::cast<std::string>(target->cast<ir::ValuePtr>()->GetValue());
    name_2_node_.erase(stored_name);
  } else {
    node->SetRightArg(Mutate(node->GetRightArg()));
  }
  return node;
}
205
Run()206 void FuncInliner::Run() {
207 detector_->Run();
208 Mutate(func_);
209 InsertSubFunction();
210 }
211
InsertSubFunction()212 void FuncInliner::InsertSubFunction() {
213 ir::NodePtrList &roots = func_->GetNodes();
214 for (auto &[index, func_node] : index_2_function_) {
215 size_t idx = index + inserted_nodes_cnt_;
216 roots.insert(roots.begin() + idx, func_node->GetNodes().begin(), func_node->GetNodes().end() - 1);
217 inserted_nodes_cnt_ += func_node->GetNodes().size() - 1;
218 }
219 }
220
// Replace a call whose callee is a function node by the value returned by that function.
// A call whose first operand is not a function node is returned unchanged.
// Otherwise the function node is recorded in index_2_function_ under the root index of the call, and
// the operand of the function's trailing return node is returned as the replacement.
// Raises an exception when the detector index is out of range or does not match the recorded root
// node, when the function node has no nodes, or when its last node is not a return node.
// NOTE(review): index_2_function_ is keyed by root index only, so two inlined calls under the same
// root node would overwrite each other here - confirm whether that case can occur.
ir::NodePtr FuncInliner::Mutate_(const ir::CallNodePtr &node) {
  auto func = node->GetArg(0);
  if (!func->isa<ir::FunctionNode>()) {
    return node;
  }
  auto func_node = func->cast<ir::FunctionNodePtr>();
  const size_t index = detector_->GetRootNodeIndex(node) + inserted_nodes_cnt_;
  // Check the index before dereferencing it: an out-of-range index would be undefined behaviour.
  MS_EXCEPTION_IF_CHECK_FAIL(index < func_->GetNodes().size(), "Detector index out of range.");
  auto root = *(func_->GetNodes().begin() + index);
  MS_EXCEPTION_IF_CHECK_FAIL(root == detector_->GetRootNode(node), "Detector index error.");
  index_2_function_[index] = func_node;
  // back() on an empty node list is undefined, so reject an empty function explicitly.
  MS_EXCEPTION_IF_CHECK_FAIL(!func_node->GetNodes().empty(), "Function to inline has no nodes.");
  auto ret = func_node->GetNodes().back();
  MS_EXCEPTION_IF_CHECK_FAIL(ret->isa<ir::ReturnNode>(), "Expected Return Node, but got " + ret->GetNodeName() + ".");
  return ret->cast<ir::ReturnNodePtr>()->GetReturn();
}
235 } // namespace pijit
236 } // namespace mindspore
237