1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <gtest/gtest.h>
17
18 #include "libabckit/include/cpp/abckit_cpp.h"
19
20 #include "helpers/helpers.h"
21 #include "helpers/helpers_runtime.h"
22 #include "libabckit/src/logger.h"
23
24 namespace {
25
26 class ErrorHandler final : public abckit::IErrorHandler {
HandleError(abckit::Exception && e)27 void HandleError(abckit::Exception &&e) override
28 {
29 EXPECT_TRUE(false) << "Abckit exception raised: " << e.what();
30 }
31 };
32
TransformMethod(const abckit::core::Function & f,const std::function<void (const abckit::File *,abckit::core::Function)> & cb)33 inline void TransformMethod(const abckit::core::Function &f,
34 const std::function<void(const abckit::File *, abckit::core::Function)> &cb)
35 {
36 cb(f.GetFile(), f);
37 }
38
AddParamChecker(const abckit::core::Function & method)39 void AddParamChecker(const abckit::core::Function &method)
40 {
41 abckit::Graph graph = method.CreateGraph();
42
43 TransformMethod(method, [&]([[maybe_unused]] const abckit::File *file, const abckit::core::Function &method) {
44 abckit::BasicBlock startBB = graph.GetStartBb();
45 abckit::Instruction idx = startBB.GetLastInst();
46 abckit::Instruction arr = idx.GetPrev();
47
48 std::vector<abckit::BasicBlock> succBBs = startBB.GetSuccs();
49
50 std::string str = "length";
51
52 abckit::Instruction constant = graph.FindOrCreateConstantI32(-1);
53 abckit::Instruction arrLength = graph.DynIsa().CreateLdobjbyname(arr, str);
54
55 abckit::BasicBlock trueBB = succBBs[0];
56 startBB.EraseSuccBlock(ABCKIT_TRUE_SUCC_IDX);
57 abckit::BasicBlock falseBB = graph.CreateEmptyBb();
58 abckit::BasicBlock endBb = graph.GetEndBb();
59 falseBB.AppendSuccBlock(endBb);
60 falseBB.AddInstBack(graph.DynIsa().CreateReturn(constant));
61 abckit::BasicBlock ifBB = graph.CreateEmptyBb();
62 abckit::Instruction intrinsicGreatereq = graph.DynIsa().CreateGreatereq(arrLength, idx);
63 abckit::Instruction ifInst =
64 graph.DynIsa().CreateIf(intrinsicGreatereq, ABCKIT_ISA_API_DYNAMIC_CONDITION_CODE_CC_EQ);
65 ifBB.AddInstBack(arrLength);
66 ifBB.AddInstBack(intrinsicGreatereq);
67 ifBB.AddInstBack(ifInst);
68 startBB.AppendSuccBlock(ifBB);
69 ifBB.AppendSuccBlock(trueBB);
70 ifBB.AppendSuccBlock(falseBB);
71
72 method.SetGraph(graph);
73 });
74 }
75
// Identifies the method to patch. The field order matters: the struct is
// filled with positional aggregate initialisation.
struct MethodInfo {
    // Name of the imported module; compared with the name of an import's source module.
    std::string path;
    // Name of the import (the class) that is looked up in that module.
    std::string className;
    // Name of the method that is searched for inside a class.
    std::string methodName;
};
81
// Searches the imports of `module` for the one whose name equals
// `methodInfo.className` and whose imported module is named `methodInfo.path`.
//
// Returns the first matching import descriptor, or a default-constructed
// descriptor when no import matches.
//
// The enumeration callback returns `true` to continue with the next import and
// `false` to stop. Non-matching imports therefore have to return `true`;
// returning `false` for them would end the search at the first import that
// does not match, even if a matching one follows.
abckit::core::ImportDescriptor GetImportDescriptor(const abckit::core::Module &module, MethodInfo &methodInfo)
{
    abckit::core::ImportDescriptor impDescriptor;
    module.EnumerateImports([&](const abckit::core::ImportDescriptor &id) -> bool {
        auto importName = id.GetName();
        auto importedModule = id.GetImportedModule();
        auto source = importedModule.GetName();
        if (source != methodInfo.path) {
            // Import comes from another module: keep looking.
            return true;
        }
        if (importName != methodInfo.className) {
            // Right module, different imported name: keep looking.
            return true;
        }
        impDescriptor = id;
        // Match found: stop the enumeration.
        return false;
    });
    return impDescriptor;
}
100
EnumerateModuleFunctions(const abckit::core::Module & mod,const std::function<bool (abckit::core::Function)> & cb)101 inline void EnumerateModuleFunctions(const abckit::core::Module &mod,
102 const std::function<bool(abckit::core::Function)> &cb)
103 {
104 // NOTE: currently we can only enumerate class methods and top level functions. need to update.
105 mod.EnumerateTopLevelFunctions(cb);
106 mod.EnumerateClasses([&](const abckit::core::Class &klass) -> bool {
107 klass.EnumerateMethods(cb);
108 return true;
109 });
110 }
111
EnumerateFunctionInsts(const abckit::core::Function & func,const std::function<void (abckit::Instruction)> & cb)112 inline void EnumerateFunctionInsts(const abckit::core::Function &func,
113 const std::function<void(abckit::Instruction)> &cb)
114 {
115 abckit::Graph graph = func.CreateGraph();
116 graph.EnumerateBasicBlocksRpo([&](const abckit::BasicBlock &bb) {
117 for (auto inst = bb.GetFirstInst(); inst; inst = inst.GetNext()) {
118 cb(inst);
119 }
120 return true;
121 });
122 }
123
// Returns the method of `klass` whose name equals `methodInfo.methodName`.
// All methods are visited; if several share that name the last one visited is
// returned. When none matches, a default-constructed function is returned.
abckit::core::Function GetMethodToModify(const abckit::core::Class &klass, MethodInfo &methodInfo)
{
    abckit::core::Function result;
    klass.EnumerateMethods([&](const abckit::core::Function &candidate) -> bool {
        if (candidate.GetName() == methodInfo.methodName) {
            result = candidate;
        }
        // Never stop early: every method of the class is inspected.
        return true;
    });
    return result;
}
136
// Looks for the method to patch starting from instruction `inst`.
//
// `inst` is only considered when it is an LDEXTERNALMODULEVAR instruction whose
// import descriptor equals `id`. Its users are then visited; for the first user
// that is a DEFINECLASSWITHBUFFER instruction, the parent class of the user's
// function is searched with GetMethodToModify and the visit stops.
//
// Returns the method found that way, or a default-constructed function when
// `inst` does not qualify or no such user exists.
abckit::core::Function GetSubclassMethod(const abckit::core::ImportDescriptor &id, const abckit::Instruction &inst,
                                         MethodInfo &methodInfo)
{
    abckit::core::Function result;

    // The import descriptor is only queried once the opcode has been confirmed.
    if (inst.GetGraph()->DynIsa().GetOpcode(inst) != ABCKIT_ISA_API_DYNAMIC_OPCODE_LDEXTERNALMODULEVAR ||
        inst.GetGraph()->DynIsa().GetImportDescriptor(inst) != id) {
        return result;
    }

    inst.VisitUsers([&](const abckit::Instruction &user) {
        if (user.GetGraph()->DynIsa().GetOpcode(user) != ABCKIT_ISA_API_DYNAMIC_OPCODE_DEFINECLASSWITHBUFFER) {
            // Not a class definition: continue with the next user.
            return true;
        }
        auto definingFunc = user.GetFunction();
        auto parentClass = definingFunc.GetParentClass();
        result = GetMethodToModify(parentClass, methodInfo);
        // A class definition was handled: stop visiting users.
        return false;
    });

    return result;
}
ModifyFunction(const abckit::core::Function & method,abckit::core::ImportDescriptor id,MethodInfo & methodInfo)161 void ModifyFunction(const abckit::core::Function &method, abckit::core::ImportDescriptor id, MethodInfo &methodInfo)
162 {
163 EnumerateFunctionInsts(method, [&](const abckit::Instruction &inst) {
164 auto subclassMethod = GetSubclassMethod(id, inst, methodInfo);
165 if (subclassMethod) {
166 AddParamChecker(subclassMethod);
167 }
168 });
169 }
170
171 } // namespace
172
173 namespace libabckit::test {
174
175 class AbckitScenarioCppTestClean : public ::testing::Test {};
176
177 // Test: test-kind=scenario, abc-kind=ArkTS1, category=positive, extension=cpp
TEST_F(AbckitScenarioCppTestClean, LibAbcKitTestDynamicParameterCheckClean)
{
    // Output expected from the untouched input file.
    const char *const expectedBefore =
        "str1\n"
        "str2\n"
        "str3\n"
        "undefined\n"
        "str3\n"
        "str2\n"
        "str4\n"
        "undefined\n";
    // Output expected from the patched file: the two "undefined" lines become "-1".
    const char *const expectedAfter =
        "str1\n"
        "str2\n"
        "str3\n"
        "-1\n"
        "str3\n"
        "str2\n"
        "str4\n"
        "-1\n";

    const std::string testSandboxPath = ABCKIT_ABC_DIR "clean_scenarios/cpp_api/dynamic/parameter_check/";
    const std::string inputAbcPath = testSandboxPath + "parameter_check.abc";
    const std::string outputAbcPath = testSandboxPath + "parameter_check_modified.abc";
    const std::string entryPoint = "parameter_check";

    abckit::File file(inputAbcPath, std::make_unique<ErrorHandler>());

    // Step 1: run the original file and check its output.
    auto output = helpers::ExecuteDynamicAbc(inputAbcPath, entryPoint);
    EXPECT_TRUE(helpers::Match(output, expectedBefore));

    // Step 2: in every module that imports Handler from modules/base, patch the
    // `handle` methods reachable from the module's functions.
    MethodInfo methodInfo {"modules/base", "Handler", "handle"};

    file.EnumerateModules([&](const abckit::core::Module &mod) -> bool {
        abckit::core::ImportDescriptor impDescriptor = GetImportDescriptor(mod, methodInfo);
        if (!impDescriptor) {
            // Module does not have the import: go on with the next module.
            return true;
        }
        EnumerateModuleFunctions(mod, [&](const abckit::core::Function &method) -> bool {
            ModifyFunction(method, impDescriptor, methodInfo);
            return true;
        });
        return true;
    });

    // Step 3: write the patched file, run it and check the new output.
    file.WriteAbc(outputAbcPath);

    output = helpers::ExecuteDynamicAbc(outputAbcPath, entryPoint);

    EXPECT_TRUE(helpers::Match(output, expectedAfter));
}
225
226 } // namespace libabckit::test
227