1 /**
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "string_builder_utils.h"
17
18 namespace ark::compiler {
19
IsStringBuilderInstance(Inst * inst)20 bool IsStringBuilderInstance(Inst *inst)
21 {
22 if (inst->GetOpcode() != Opcode::NewObject) {
23 return false;
24 }
25
26 auto klass = GetObjectClass(inst->CastToNewObject());
27 if (klass == nullptr) {
28 return false;
29 }
30
31 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
32 return runtime->IsClassStringBuilder(klass);
33 }
34
IsMethodStringConcat(Inst * inst)35 bool IsMethodStringConcat(Inst *inst)
36 {
37 if (inst->GetOpcode() != Opcode::CallStatic && inst->GetOpcode() != Opcode::CallVirtual) {
38 return false;
39 }
40
41 auto call = static_cast<CallInst *>(inst);
42 if (call->IsInlined()) {
43 return false;
44 }
45
46 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
47 return runtime->IsMethodStringConcat(call->GetCallMethod());
48 }
49
IsMethodStringBuilderConstructorWithStringArg(Inst * inst)50 bool IsMethodStringBuilderConstructorWithStringArg(Inst *inst)
51 {
52 if (inst->GetOpcode() != Opcode::CallStatic) {
53 return false;
54 }
55
56 auto call = inst->CastToCallStatic();
57 if (call->IsInlined()) {
58 return false;
59 }
60
61 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
62 return runtime->IsMethodStringBuilderConstructorWithStringArg(call->GetCallMethod());
63 }
64
IsMethodStringBuilderConstructorWithCharArrayArg(Inst * inst)65 bool IsMethodStringBuilderConstructorWithCharArrayArg(Inst *inst)
66 {
67 if (inst->GetOpcode() != Opcode::CallStatic) {
68 return false;
69 }
70
71 auto call = inst->CastToCallStatic();
72 if (call->IsInlined()) {
73 return false;
74 }
75
76 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
77 return runtime->IsMethodStringBuilderConstructorWithCharArrayArg(call->GetCallMethod());
78 }
79
IsStringBuilderToString(Inst * inst)80 bool IsStringBuilderToString(Inst *inst)
81 {
82 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
83 if (inst->GetOpcode() == Opcode::CallStatic || inst->GetOpcode() == Opcode::CallVirtual) {
84 auto callInst = static_cast<CallInst *>(inst);
85 return !callInst->IsInlined() && runtime->IsMethodStringBuilderToString(callInst->GetCallMethod());
86 }
87 if (inst->IsIntrinsic()) {
88 auto intrinsic = inst->CastToIntrinsic();
89 return runtime->IsIntrinsicStringBuilderToString(intrinsic->GetIntrinsicId());
90 }
91 return false;
92 }
93
IsMethodStringBuilderDefaultConstructor(Inst * inst)94 bool IsMethodStringBuilderDefaultConstructor(Inst *inst)
95 {
96 if (inst->GetOpcode() != Opcode::CallStatic) {
97 return false;
98 }
99
100 auto call = inst->CastToCallStatic();
101 if (call->IsInlined()) {
102 return false;
103 }
104
105 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
106 return runtime->IsMethodStringBuilderDefaultConstructor(call->GetCallMethod());
107 }
108
InsertBeforeWithSaveState(Inst * inst,Inst * before)109 void InsertBeforeWithSaveState(Inst *inst, Inst *before)
110 {
111 if (inst->RequireState()) {
112 before->InsertBefore(inst->GetSaveState());
113 }
114 before->InsertBefore(inst);
115 }
116
InsertAfterWithSaveState(Inst * inst,Inst * after)117 void InsertAfterWithSaveState(Inst *inst, Inst *after)
118 {
119 after->InsertAfter(inst);
120 if (inst->RequireState()) {
121 after->InsertAfter(inst->GetSaveState());
122 }
123 }
124
InsertBeforeWithInputs(Inst * inst,Inst * before)125 void InsertBeforeWithInputs(Inst *inst, Inst *before)
126 {
127 for (auto &input : inst->GetInputs()) {
128 auto inputInst = input.GetInst();
129 if (inputInst->GetBasicBlock() == nullptr) {
130 InsertBeforeWithInputs(inputInst, before);
131 }
132 }
133
134 if (inst->GetBasicBlock() == nullptr) {
135 before->InsertBefore(inst);
136 }
137 }
138
HasInput(Inst * inst,const FindInputPredicate & predicate)139 bool HasInput(Inst *inst, const FindInputPredicate &predicate)
140 {
141 // Check if any instruction input satisfy predicate
142
143 auto found = std::find_if(inst->GetInputs().begin(), inst->GetInputs().end(), predicate);
144 return found != inst->GetInputs().end();
145 }
146
HasInputPhiRecursively(Inst * inst,Marker visited,const FindInputPredicate & predicate)147 bool HasInputPhiRecursively(Inst *inst, Marker visited, const FindInputPredicate &predicate)
148 {
149 // Check if any instruction input satisfy predicate
150 // All Phi-instruction inputs are checked recursively
151
152 if (HasInput(inst, predicate)) {
153 return true;
154 }
155
156 inst->SetMarker(visited);
157
158 for (auto &input : inst->GetInputs()) {
159 auto inputInst = input.GetInst();
160 if (!inputInst->IsPhi()) {
161 continue;
162 }
163 if (inputInst->IsMarked(visited)) {
164 continue;
165 }
166 if (HasInputPhiRecursively(inputInst, visited, predicate)) {
167 return true;
168 }
169 }
170
171 return false;
172 }
173
ResetInputMarkersRecursively(Inst * inst,Marker visited)174 void ResetInputMarkersRecursively(Inst *inst, Marker visited)
175 {
176 // Reset marker for an instruction and all it's inputs recursively
177
178 if (inst->IsMarked(visited)) {
179 inst->ResetMarker(visited);
180
181 for (auto &input : inst->GetInputs()) {
182 auto inputInst = input.GetInst();
183 if (inputInst->IsMarked(visited)) {
184 ResetInputMarkersRecursively(inputInst, visited);
185 }
186 }
187 }
188 }
189
HasUser(Inst * inst,const FindUserPredicate & predicate)190 bool HasUser(Inst *inst, const FindUserPredicate &predicate)
191 {
192 // Check if instruction is used in a context defined by predicate
193
194 auto found = std::find_if(inst->GetUsers().begin(), inst->GetUsers().end(), predicate);
195 return found != inst->GetUsers().end();
196 }
197
HasUserPhiRecursively(Inst * inst,Marker visited,const FindUserPredicate & predicate)198 bool HasUserPhiRecursively(Inst *inst, Marker visited, const FindUserPredicate &predicate)
199 {
200 // Check if instruction is used in a context defined by predicate
201 // All Phi-instruction users are checked recursively
202
203 if (HasUser(inst, predicate)) {
204 return true;
205 }
206
207 inst->SetMarker(visited);
208
209 for (auto &user : inst->GetUsers()) {
210 auto userInst = user.GetInst();
211 if (!userInst->IsPhi()) {
212 continue;
213 }
214 if (userInst->IsMarked(visited)) {
215 continue;
216 }
217 if (HasUserPhiRecursively(userInst, visited, predicate)) {
218 return true;
219 }
220 }
221
222 return false;
223 }
224
HasUserRecursively(Inst * inst,Marker visited,const FindUserPredicate & predicate)225 bool HasUserRecursively(Inst *inst, Marker visited, const FindUserPredicate &predicate)
226 {
227 // Check if instruction is used in a context defined by predicate
228 // All Check-instruction users are checked recursively
229
230 if (HasUser(inst, predicate)) {
231 return true;
232 }
233
234 inst->SetMarker(visited);
235
236 for (auto &user : inst->GetUsers()) {
237 auto userInst = user.GetInst();
238 if (!userInst->IsCheck()) {
239 continue;
240 }
241 if (userInst->IsMarked(visited)) {
242 continue;
243 }
244 if (HasUserRecursively(userInst, visited, predicate)) {
245 return true;
246 }
247 }
248
249 return false;
250 }
251
CountUsers(Inst * inst,const FindUserPredicate & predicate)252 size_t CountUsers(Inst *inst, const FindUserPredicate &predicate)
253 {
254 size_t count = 0;
255 for (auto &user : inst->GetUsers()) {
256 if (predicate(user)) {
257 ++count;
258 }
259
260 auto userInst = user.GetInst();
261 if (userInst->IsCheck()) {
262 count += CountUsers(userInst, predicate);
263 }
264 }
265
266 return count;
267 }
268
ResetUserMarkersRecursively(Inst * inst,Marker visited)269 void ResetUserMarkersRecursively(Inst *inst, Marker visited)
270 {
271 // Reset marker for an instruction and all it's users recursively
272
273 if (inst->IsMarked(visited)) {
274 inst->ResetMarker(visited);
275
276 for (auto &user : inst->GetUsers()) {
277 auto userInst = user.GetInst();
278 if (userInst->IsMarked(visited)) {
279 ResetUserMarkersRecursively(userInst, visited);
280 }
281 }
282 }
283 }
284
SkipSingleUserCheckInstruction(Inst * inst)285 Inst *SkipSingleUserCheckInstruction(Inst *inst)
286 {
287 if (inst->IsCheck() && inst->HasSingleUser()) {
288 inst = inst->GetUsers().Front().GetInst();
289 }
290 return inst;
291 }
292
IsIntrinsicStringBuilderAppendString(Inst * inst)293 bool IsIntrinsicStringBuilderAppendString(Inst *inst)
294 {
295 if (!inst->IsIntrinsic()) {
296 return false;
297 }
298
299 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
300 return runtime->IsIntrinsicStringBuilderAppendString(inst->CastToIntrinsic()->GetIntrinsicId());
301 }
302
IsUsedOutsideBasicBlock(Inst * inst,BasicBlock * bb)303 bool IsUsedOutsideBasicBlock(Inst *inst, BasicBlock *bb)
304 {
305 for (auto &user : inst->GetUsers()) {
306 auto userInst = user.GetInst();
307 if (userInst->IsCheck()) {
308 if (!userInst->HasUsers()) {
309 continue;
310 }
311 if (!userInst->HasSingleUser()) {
312 // In case of multi user check-instruction we assume it is used outside current basic block without
313 // actually testing it.
314 return true;
315 }
316 // In case of single user check-instruction we test its the only user.
317 userInst = userInst->GetUsers().Front().GetInst();
318 }
319 if (userInst->GetBasicBlock() != bb) {
320 return true;
321 }
322 }
323 return false;
324 }
325
FindFirstSaveState(BasicBlock * block)326 SaveStateInst *FindFirstSaveState(BasicBlock *block)
327 {
328 if (block->IsEmpty()) {
329 return nullptr;
330 }
331
332 for (auto inst : block->Insts()) {
333 if (inst->GetOpcode() == Opcode::SaveState) {
334 return inst->CastToSaveState();
335 }
336 }
337
338 return nullptr;
339 }
340
RemoveFromInstructionInputs(ArenaVector<InputDesc> & inputDescriptors)341 void RemoveFromInstructionInputs(ArenaVector<InputDesc> &inputDescriptors)
342 {
343 // Inputs must be walked in reverse order for removal
344 std::sort(inputDescriptors.begin(), inputDescriptors.end(),
345 [](auto inputDescX, auto inputDescY) { return inputDescX.second > inputDescY.second; });
346
347 for (auto inputDesc : inputDescriptors) {
348 auto inst = inputDesc.first;
349 auto index = inputDesc.second;
350 inst->RemoveInput(index);
351 }
352 }
353
BreakStringBuilderAppendChains(BasicBlock * block)354 bool BreakStringBuilderAppendChains(BasicBlock *block)
355 {
356 // StringBuilder append-call returns 'this' (instance)
357 // Replace all users of append-call with instance itself to support chain calls
358 // like: sb.append(s0).append(s1)...
359 bool isApplied = false;
360 for (auto inst : block->Insts()) {
361 if (!IsStringBuilderAppend(inst) && !IsStringBuilderToString(inst)) {
362 continue;
363 }
364
365 auto instance = inst->GetDataFlowInput(0);
366 for (auto &user : instance->GetUsers()) {
367 auto userInst = SkipSingleUserCheckInstruction(user.GetInst());
368 if (IsStringBuilderAppend(userInst)) {
369 userInst->ReplaceUsers(instance);
370 isApplied = true;
371 }
372 }
373 }
374 return isApplied;
375 }
376
GetStoreArrayIndexConstant(Inst * storeArray)377 Inst *GetStoreArrayIndexConstant(Inst *storeArray)
378 {
379 ASSERT(storeArray->GetOpcode() == Opcode::StoreArray);
380 ASSERT(storeArray->GetInputsCount() > 1);
381
382 auto inputInst1 = storeArray->GetDataFlowInput(1U);
383 if (inputInst1->IsConst()) {
384 return inputInst1;
385 }
386
387 return nullptr;
388 }
389
FillArrayElement(Inst * inst,InstVector & arrayElements)390 bool FillArrayElement(Inst *inst, InstVector &arrayElements)
391 {
392 if (inst->GetOpcode() == Opcode::StoreArray) {
393 auto indexInst = GetStoreArrayIndexConstant(inst);
394 if (indexInst == nullptr) {
395 return false;
396 }
397
398 ASSERT(indexInst->IsConst());
399 auto indexValue = indexInst->CastToConstant()->GetIntValue();
400 if (arrayElements[indexValue] != nullptr) {
401 return false;
402 }
403
404 auto element = inst->GetDataFlowInput(2U);
405 arrayElements[indexValue] = element;
406 }
407 return true;
408 }
409
FillArrayElements(Inst * inst,InstVector & arrayElements)410 bool FillArrayElements(Inst *inst, InstVector &arrayElements)
411 {
412 for (auto &user : inst->GetUsers()) {
413 auto userInst = user.GetInst();
414 if (!FillArrayElement(userInst, arrayElements)) {
415 return false;
416 }
417 if (userInst->GetOpcode() == Opcode::NullCheck) {
418 if (!FillArrayElements(userInst, arrayElements)) {
419 return false;
420 }
421 }
422 }
423 return true;
424 }
425
GetArrayLengthConstant(Inst * newArray)426 Inst *GetArrayLengthConstant(Inst *newArray)
427 {
428 if (newArray->GetOpcode() != Opcode::NewArray) {
429 return nullptr;
430 }
431 ASSERT(newArray->GetInputsCount() > 1);
432
433 auto inputInst1 = newArray->GetDataFlowInput(1U);
434 if (inputInst1->IsConst()) {
435 return inputInst1;
436 }
437
438 return nullptr;
439 }
440
CollectArrayElements(Inst * newArray,InstVector & arrayElements)441 bool CollectArrayElements(Inst *newArray, InstVector &arrayElements)
442 {
443 /*
444 Collect instructions stored to a given array
445
446 This functions used to find all the arguments of the calls like:
447 str.concat(a, b, c)
448 IR builder generates the following IR for it:
449
450 bb_start:
451 v0 Constant 0x0
452 v1 Constant 0x1
453 v2 Constant 0x2
454 v3 Constant 0x3
455 bb1:
456 v9 NewArray class, v3, save_state
457 v10 StoreArray v9, v0, a
458 v11 StoreArray v9, v1, b
459 v12 StoreArray v9, v2, c
460 v20 CallStatic String::concat str, v9, save_state
461
462 Conditions:
463 - array size is constant (3 in the sample code above)
464 - every StoreArray instruction stores value by constant index (0, 1 and 2 in the sample code above)
465 - every element stored only once
466 - array filled completely
467
468 If any of the above is false, this functions returns false and clears array.
469 If all the above conditions true, this function returns true and fills array.
470 */
471
472 ASSERT(newArray->GetOpcode() == Opcode::NewArray);
473 arrayElements.clear();
474
475 auto lengthInst = GetArrayLengthConstant(newArray);
476 if (lengthInst == nullptr) {
477 return false;
478 }
479 ASSERT(lengthInst->IsConst());
480
481 auto length = lengthInst->CastToConstant()->GetIntValue();
482 arrayElements.resize(length);
483
484 if (!FillArrayElements(newArray, arrayElements)) {
485 arrayElements.clear();
486 return false;
487 }
488
489 // Check if array is filled completely
490 auto foundNull =
491 std::find_if(arrayElements.begin(), arrayElements.end(), [](auto &element) { return element == nullptr; });
492 if (foundNull != arrayElements.end()) {
493 arrayElements.clear();
494 return false;
495 }
496
497 return true;
498 }
499
CleanupStoreArrayInstructions(Inst * inst)500 void CleanupStoreArrayInstructions(Inst *inst)
501 {
502 for (auto &user : inst->GetUsers()) {
503 auto userInst = user.GetInst();
504 if (userInst->GetOpcode() == Opcode::StoreArray) {
505 userInst->ClearFlag(inst_flags::NO_DCE);
506 }
507 if (userInst->IsCheck()) {
508 userInst->ClearFlag(inst_flags::NO_DCE);
509 CleanupStoreArrayInstructions(userInst);
510 }
511 }
512 }
513
514 } // namespace ark::compiler
515