• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/pi/graph_compiler/inliner/func_inliner.h"
18 #include <memory>
19 #include <string>
20 #include <algorithm>
21 #include "pipeline/jit/pi/graph_compiler/parser/byte_code_parser.h"
22 #include "pybind11/pytypes.h"
23 #include "utils/log_adapter.h"
24 
25 namespace mindspore {
26 namespace pijit {
27 namespace {
GetPyFunction(const ir::NodePtr & node)28 py::object GetPyFunction(const ir::NodePtr &node) {
29   if (node->isa<ir::Value>()) {
30     auto value = node->cast<ir::ValuePtr>()->GetValue();
31     if (py::isinstance<py::function>(value)) {
32       return value;
33     }
34   } else {
35     MS_EXCEPTION_IF_CHECK_FAIL(node->isa<ir::LoadValueNode>(), "Expected load a function.");
36     auto load = node->cast<ir::LoadValueNodePtr>();
37     if (load->GetOpCode() == LOAD_GLOBAL) {
38       return GetPyFunction(load->GetArgs().back());
39     }
40   }
41   return py::none();
42 }
43 }  // namespace
44 
Run()45 void FuncInlineDetector::Run() { Visit(func_); }
46 
Visit_(const ir::FunctionNodePtr & node)47 void FuncInlineDetector::Visit_(const ir::FunctionNodePtr &node) {
48   std::for_each(node->GetNodes().begin(), node->GetNodes().end(), [this](const ir::NodePtr &node) {
49     cur_root_node_ = node;
50     Visit(node);
51     index_++;
52   });
53 }
54 
55 // Input form : {BUILD_TUPLE, (pos_args)}
56 //              {BUILD_LIST, (pos_args)}
57 //              {LIST_EXTEND, ([pos_args], (vargs))}
58 //              {LIST_TO_TUPLE, [pos_args, vargs]}
UnpackArgShell(const ir::NodePtr & arg)59 const ir::NodePtrList &UnpackArgShell(const ir::NodePtr &arg) {
60   MS_EXCEPTION_IF_CHECK_FAIL(arg->isa<ir::Operation>(), "Arg should be a operation.");
61   const auto op = arg->cast<ir::OperationPtr>();
62   bool is_expected = op->GetOpCode() == BUILD_LIST || op->GetOpCode() == BUILD_TUPLE ||
63                      op->GetOpCode() == LIST_TO_TUPLE || op->GetOpCode() == LIST_EXTEND;
64   MS_EXCEPTION_IF_CHECK_FAIL(is_expected, "Not expected operation.");
65   if (arg->isa<ir::CastNode>()) {
66     return UnpackArgShell(op->GetArg());
67   }
68   return op->GetArgs();
69 }
70 
71 // Input form : {DICT_MERGE, ({}, kwargs)}
UnpackKwargsShell(const ir::NodePtr & kwargs)72 const ir::NodePtr &UnpackKwargsShell(const ir::NodePtr &kwargs) {
73   if (!kwargs->isa<ir::UpdateNode>()) {
74     return kwargs;
75   }
76   auto node = kwargs->cast<ir::UpdateNodePtr>();
77   bool is_valid = node->GetArg()->isa<ir::BuildNode>() && node->GetArg()->cast<ir::BuildNodePtr>()->GetArgsCnt() == 0;
78   MS_EXCEPTION_IF_CHECK_FAIL(is_valid, "First arg should be a empty build node.");
79   return node->GetArg(1);
80 }
81 
82 // Input form : {}
83 //              {(pos_args)}
84 //              {varargs}
85 //              {kwargs}
86 //              {((pos_args), vargs)}
87 //              {(pos_args), kwargs}
88 //              {vargs, kwargs}
89 //              {((pos_args), vargs), kwargs}
90 // Output form : {pos_args..., vargs, kwargs}
UnpackArgsInTuple(const py::object & func,ir::NodePtrList * args)91 void UnpackArgsInTuple(const py::object &func, ir::NodePtrList *args) {
92   MS_EXCEPTION_IF_CHECK_FAIL(py::isinstance<py::function>(func), "Should be a function object.");
93   const auto code = reinterpret_cast<const PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
94   if ((IntToSize(code->co_flags) & CO_VARKEYWORDS) != 0) {
95     args->back() = UnpackKwargsShell(args->back());
96   }
97   if (code->co_argcount == 0) {
98     return;
99   }
100   const auto &inner_args = UnpackArgShell(args->front());
101   args->erase(args->begin());
102   args->insert(args->begin(), inner_args.begin(), inner_args.end());
103   if ((IntToSize(code->co_flags) & CO_VARARGS) != 0) {
104     const auto &pos_args = UnpackArgShell(args->front());
105     args->erase(args->begin());
106     args->insert(args->begin(), pos_args.begin(), pos_args.end());
107   }
108 }
109 
Visit_(const ir::CallNodePtr & node)110 void FuncInlineDetector::Visit_(const ir::CallNodePtr &node) {
111   auto arg = node->GetArg(0);
112   if (!CanBeInlined(arg)) {
113     std::for_each(node->GetArgs().begin(), node->GetArgs().end(), [this](const ir::NodePtr &node) { Visit(node); });
114   } else {
115     const py::object func = GetPyFunction(arg);
116     auto byteCodeParser = std::make_shared<ByteCodeParser>(func);
117     ir::FunctionNodePtr func_node = byteCodeParser->Parse();
118     ir::NodePtrList args(node->GetArgs().begin() + 1, node->GetArgs().end());
119     if (node->GetOpCode() != CALL_FUNCTION) {
120       UnpackArgsInTuple(func, &args);
121     }
122     EvolvingFunction(func_node, args);
123     node->SetArg(0, func_node);
124     node_2_index_[node] = index_;
125     node_2_root_[node] = cur_root_node_;
126     std::for_each(node->GetArgs().begin() + 1, node->GetArgs().end(), [this](const ir::NodePtr &node) { Visit(node); });
127   }
128 }
129 
GetRootNodeIndex(const ir::CallNodePtr & node) const130 size_t FuncInlineDetector::GetRootNodeIndex(const ir::CallNodePtr &node) const {
131   MS_EXCEPTION_IF_CHECK_FAIL(node_2_index_.find(node) != node_2_index_.end(),
132                              "Invalid Call Node %" + std::to_string(node->GetNodeId()) + ".");
133   return node_2_index_.at(node);
134 }
135 
// Returns the root statement that contained the inlined call `node` when it was detected.
// Throws if `node` was not recorded by the detector. The reference points into the detector's own map.
const ir::NodePtr &FuncInlineDetector::GetRootNode(const ir::CallNodePtr &node) const {
  const auto found = node_2_root_.find(node);
  MS_EXCEPTION_IF_CHECK_FAIL(found != node_2_root_.end(),
                             "Invalid Call Node %" + std::to_string(node->GetNodeId()) + ".");
  return found->second;
}
141 
CanBeInlined(const ir::NodePtr & node) const142 bool FuncInlineDetector::CanBeInlined(const ir::NodePtr &node) const {
143   if (!node->isa<ir::Value>() && !node->isa<ir::LoadValueNode>()) {
144     return false;
145   }
146   auto func = GetPyFunction(node);
147   return !py::isinstance<py::none>(func) && PyFunction_Check(func.ptr());
148 }
149 
EvolvingFunction(const ir::FunctionNodePtr & func_node,const ir::NodePtrList args) const150 void FuncInlineDetector::EvolvingFunction(const ir::FunctionNodePtr &func_node, const ir::NodePtrList args) const {
151   // Rename the local variables of the function to avoid variable name conflicts after inlining
152   auto renamer = std::make_shared<FuncLocalVarRenamer>(func_node);
153   renamer->Run();
154   // Eliminate parameters
155   auto eliminator = std::make_shared<FuncParameterEliminator>(func_node, args);
156   eliminator->Run();
157 }
158 
Run()159 void FuncLocalVarRenamer::Run() { Visit(func_); }
160 
Visit_(const ir::ParameterPtr & node)161 void FuncLocalVarRenamer::Visit_(const ir::ParameterPtr &node) {
162   node->SetName(func_->GetName() + "_" + node->GetName());
163 }
164 
Visit_(const ir::ValuePtr & node)165 void FuncLocalVarRenamer::Visit_(const ir::ValuePtr &node) {
166   if (node->GetScope() == ir::kScopeLocal) {
167     auto name = func_->GetName() + "_" + node->GetName();
168     node->SetValue(py::str(name));
169     node->SetName(name);
170   }
171 }
172 
Run()173 void FuncParameterEliminator::Run() { Mutate(func_); }
174 
// Records which node should replace loads of this parameter.
// When the caller supplied an argument at the parameter's position, that argument is used; otherwise the
// parameter's default value is used (NOTE(review): what GetDefaultValue yields when no default exists is not
// visible here -- confirm). The Parameter node itself is returned unchanged.
ir::NodePtr FuncParameterEliminator::Mutate_(const ir::ParameterPtr &node) {
  const auto &param_name = node->GetName();
  if (node->GetIndex() >= args_.size()) {
    name_2_node_[param_name] = node->GetDefaultValue();
  } else {
    name_2_node_[param_name] = args_[node->GetIndex()];
  }
  return node;
}
183 
// Replaces a load of a currently-bound parameter name with the recorded argument node.
// Any other load is returned unchanged. The load's operand must be an ir::Value (checked).
ir::NodePtr FuncParameterEliminator::Mutate_(const ir::LoadValueNodePtr &node) {
  MS_EXCEPTION_IF_CHECK_FAIL(node->GetArg()->isa<ir::Value>(), "Expected a local var name.");
  const auto var_name = node->GetArg()->cast<ir::ValuePtr>()->GetName();
  const auto hit = name_2_node_.find(var_name);
  if (hit == name_2_node_.end()) {
    return node;
  }
  return hit->second;
}
192 
// Rewrites a store node.
// The left operand (presumably the stored value) is always mutated. For STORE_FAST, the right operand must be
// a Value naming the local. Its binding in name_2_node_ is dropped, so later loads read the reassigned local
// rather than the caller's argument. For other store opcodes the right operand is mutated as well.
ir::NodePtr FuncParameterEliminator::Mutate_(const ir::StoreNodePtr &node) {
  node->SetLeftArg(Mutate(node->GetLeftArg()));
  if (node->GetOpCode() == STORE_FAST) {
    const auto target = node->GetRightArg();
    MS_EXCEPTION_IF_CHECK_FAIL(target->isa<ir::Value>(), "Expected a local var name.");
    name_2_node_.erase(py::cast<std::string>(target->cast<ir::ValuePtr>()->GetValue()));
    return node;
  }
  node->SetRightArg(Mutate(node->GetRightArg()));
  return node;
}
205 
Run()206 void FuncInliner::Run() {
207   detector_->Run();
208   Mutate(func_);
209   InsertSubFunction();
210 }
211 
InsertSubFunction()212 void FuncInliner::InsertSubFunction() {
213   ir::NodePtrList &roots = func_->GetNodes();
214   for (auto &[index, func_node] : index_2_function_) {
215     size_t idx = index + inserted_nodes_cnt_;
216     roots.insert(roots.begin() + idx, func_node->GetNodes().begin(), func_node->GetNodes().end() - 1);
217     inserted_nodes_cnt_ += func_node->GetNodes().size() - 1;
218   }
219 }
220 
// Replaces a call whose callee the detector turned into an ir::FunctionNode with that function's return
// expression. The function's body (all nodes except the trailing ReturnNode) is scheduled for insertion in
// front of the root statement containing the call; InsertSubFunction performs the insertion.
// Calls whose callee is not a FunctionNode are returned unchanged.
//
// index_2_function_ keeps one FunctionNode per root index. When several inlined calls sit in the same root
// statement (e.g. `x = f(a) + g(b)`), overwriting the entry would silently drop the earlier callee's body.
// Its return expression would still be substituted, reading locals that are never assigned. Instead, the later
// body is appended to the scheduled one, ahead of that function's ReturnNode, which InsertSubFunction drops.
// NOTE(review): the resulting order assumes the mutator reaches calls in evaluation order -- confirm.
ir::NodePtr FuncInliner::Mutate_(const ir::CallNodePtr &node) {
  auto func = node->GetArg(0);
  if (!func->isa<ir::FunctionNode>()) {
    return node;
  }
  auto func_node = func->cast<ir::FunctionNodePtr>();
  size_t index = detector_->GetRootNodeIndex(node) + inserted_nodes_cnt_;
  auto root = *(func_->GetNodes().begin() + index);
  MS_EXCEPTION_IF_CHECK_FAIL(root == detector_->GetRootNode(node), "Detector index error.");
  auto ret = func_node->GetNodes().back();
  MS_EXCEPTION_IF_CHECK_FAIL(ret->isa<ir::ReturnNode>(), "Excepted Return Node, but got " + ret->GetNodeName() + ".");
  auto scheduled = index_2_function_.find(index);
  if (scheduled == index_2_function_.end()) {
    index_2_function_[index] = func_node;
  } else if (scheduled->second != func_node) {
    // Merge instead of overwrite. The guard above avoids inserting a vector's range into itself, which is UB.
    auto &merged_nodes = scheduled->second->GetNodes();
    const auto &body = func_node->GetNodes();
    merged_nodes.insert(merged_nodes.end() - 1, body.begin(), body.end() - 1);
  }
  return ret->cast<ir::ReturnNodePtr>()->GetReturn();
}
235 }  // namespace pijit
236 }  // namespace mindspore
237