1 /**
2 * Copyright (c) 2024-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "string_builder_utils.h"
17
18 namespace ark::compiler {
19
IsStringBuilderInstance(Inst * inst)20 bool IsStringBuilderInstance(Inst *inst)
21 {
22 if (inst->GetOpcode() != Opcode::NewObject) {
23 return false;
24 }
25
26 auto klass = GetObjectClass(inst->CastToNewObject());
27 if (klass == nullptr) {
28 return false;
29 }
30
31 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
32 return runtime->IsClassStringBuilder(klass);
33 }
34
IsMethodStringConcat(const Inst * inst)35 bool IsMethodStringConcat(const Inst *inst)
36 {
37 if (inst->GetOpcode() != Opcode::CallStatic && inst->GetOpcode() != Opcode::CallVirtual) {
38 return false;
39 }
40
41 auto call = static_cast<const CallInst *>(inst);
42 if (call->IsInlined()) {
43 return false;
44 }
45
46 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
47 return runtime->IsMethodStringConcat(call->GetCallMethod());
48 }
49
IsMethodStringBuilderConstructorWithStringArg(const Inst * inst)50 bool IsMethodStringBuilderConstructorWithStringArg(const Inst *inst)
51 {
52 if (inst->GetOpcode() != Opcode::CallStatic) {
53 return false;
54 }
55
56 auto call = inst->CastToCallStatic();
57 if (call->IsInlined()) {
58 return false;
59 }
60
61 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
62 return runtime->IsMethodStringBuilderConstructorWithStringArg(call->GetCallMethod());
63 }
64
IsMethodStringBuilderConstructorWithCharArrayArg(const Inst * inst)65 bool IsMethodStringBuilderConstructorWithCharArrayArg(const Inst *inst)
66 {
67 if (inst->GetOpcode() != Opcode::CallStatic) {
68 return false;
69 }
70
71 auto call = inst->CastToCallStatic();
72 if (call->IsInlined()) {
73 return false;
74 }
75
76 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
77 return runtime->IsMethodStringBuilderConstructorWithCharArrayArg(call->GetCallMethod());
78 }
79
IsStringBuilderToString(const Inst * inst)80 bool IsStringBuilderToString(const Inst *inst)
81 {
82 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
83 if (inst->GetOpcode() == Opcode::CallStatic || inst->GetOpcode() == Opcode::CallVirtual) {
84 auto callInst = static_cast<const CallInst *>(inst);
85 return !callInst->IsInlined() && runtime->IsMethodStringBuilderToString(callInst->GetCallMethod());
86 }
87 if (inst->IsIntrinsic()) {
88 auto intrinsic = inst->CastToIntrinsic();
89 return runtime->IsIntrinsicStringBuilderToString(intrinsic->GetIntrinsicId());
90 }
91 return false;
92 }
93
IsMethodStringBuilderDefaultConstructor(const Inst * inst)94 bool IsMethodStringBuilderDefaultConstructor(const Inst *inst)
95 {
96 if (inst->GetOpcode() != Opcode::CallStatic) {
97 return false;
98 }
99
100 auto call = inst->CastToCallStatic();
101 if (call->IsInlined()) {
102 return false;
103 }
104
105 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
106 return runtime->IsMethodStringBuilderDefaultConstructor(call->GetCallMethod());
107 }
108
IsMethodOrIntrinsicCall(const Inst * inst,const Inst * self)109 static bool IsMethodOrIntrinsicCall(const Inst *inst, const Inst *self)
110 {
111 const Inst *actualSelf = nullptr;
112 if (inst->GetOpcode() == Opcode::CallStatic) {
113 auto *call = inst->CastToCallStatic();
114 if (call->IsInlined()) {
115 return false;
116 }
117 actualSelf = call->GetObjectInst();
118 } else if (inst->IsIntrinsic()) {
119 auto *intrinsic = inst->CastToIntrinsic();
120 actualSelf = intrinsic->GetInput(0).GetInst();
121 } else {
122 return false;
123 }
124
125 if (self == nullptr) {
126 return true;
127 }
128
129 // Skip NullChecks
130 while (actualSelf->IsNullCheck()) {
131 actualSelf = actualSelf->CastToNullCheck()->GetInput(0).GetInst();
132 }
133 return actualSelf == self;
134 }
135
IsStringBuilderCtorCall(const Inst * inst,const Inst * self)136 bool IsStringBuilderCtorCall(const Inst *inst, const Inst *self)
137 {
138 if (!IsMethodOrIntrinsicCall(inst, self)) {
139 return false;
140 }
141
142 return IsMethodStringBuilderDefaultConstructor(inst) || IsMethodStringBuilderConstructorWithStringArg(inst) ||
143 IsMethodStringBuilderConstructorWithCharArrayArg(inst);
144 }
145
IsStringBuilderMethod(const Inst * inst,const Inst * self)146 bool IsStringBuilderMethod(const Inst *inst, const Inst *self)
147 {
148 if (!IsMethodOrIntrinsicCall(inst, self)) {
149 return false;
150 }
151
152 return IsStringBuilderCtorCall(inst, self) || IsStringBuilderAppend(inst) || IsStringBuilderToString(inst);
153 }
154
IsNullCheck(const Inst * inst,const Inst * self)155 bool IsNullCheck(const Inst *inst, const Inst *self)
156 {
157 return inst->IsNullCheck() && (self == nullptr || inst->CastToNullCheck()->GetInput(0) == self);
158 }
159
IsIntrinsicStringConcat(const Inst * inst)160 bool IsIntrinsicStringConcat(const Inst *inst)
161 {
162 if (!inst->IsIntrinsic()) {
163 return false;
164 }
165
166 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
167 return runtime->IsIntrinsicStringConcat(inst->CastToIntrinsic()->GetIntrinsicId());
168 }
169
InsertBeforeWithSaveState(Inst * inst,Inst * before)170 void InsertBeforeWithSaveState(Inst *inst, Inst *before)
171 {
172 ASSERT(before != nullptr);
173 if (inst->RequireState()) {
174 before->InsertBefore(inst->GetSaveState());
175 }
176 before->InsertBefore(inst);
177 }
178
InsertAfterWithSaveState(Inst * inst,Inst * after)179 void InsertAfterWithSaveState(Inst *inst, Inst *after)
180 {
181 after->InsertAfter(inst);
182 if (inst->RequireState()) {
183 after->InsertAfter(inst->GetSaveState());
184 }
185 }
186
InsertBeforeWithInputs(Inst * inst,Inst * before)187 void InsertBeforeWithInputs(Inst *inst, Inst *before)
188 {
189 for (auto &input : inst->GetInputs()) {
190 auto inputInst = input.GetInst();
191 if (inputInst->GetBasicBlock() == nullptr) {
192 InsertBeforeWithInputs(inputInst, before);
193 }
194 }
195
196 if (inst->GetBasicBlock() == nullptr) {
197 ASSERT(before != nullptr);
198 before->InsertBefore(inst);
199 }
200 }
201
HasInput(Inst * inst,const FindInputPredicate & predicate)202 bool HasInput(Inst *inst, const FindInputPredicate &predicate)
203 {
204 // Check if any instruction input satisfy predicate
205
206 auto found = std::find_if(inst->GetInputs().begin(), inst->GetInputs().end(), predicate);
207 return found != inst->GetInputs().end();
208 }
209
HasInputPhiRecursively(Inst * inst,Marker visited,const FindInputPredicate & predicate)210 bool HasInputPhiRecursively(Inst *inst, Marker visited, const FindInputPredicate &predicate)
211 {
212 // Check if any instruction input satisfy predicate
213 // All Phi-instruction inputs are checked recursively
214
215 if (HasInput(inst, predicate)) {
216 return true;
217 }
218
219 inst->SetMarker(visited);
220
221 for (auto &input : inst->GetInputs()) {
222 auto inputInst = input.GetInst();
223 if (!inputInst->IsPhi()) {
224 continue;
225 }
226 if (inputInst->IsMarked(visited)) {
227 continue;
228 }
229 if (HasInputPhiRecursively(inputInst, visited, predicate)) {
230 return true;
231 }
232 }
233
234 return false;
235 }
236
ResetInputMarkersRecursively(Inst * inst,Marker visited)237 void ResetInputMarkersRecursively(Inst *inst, Marker visited)
238 {
239 // Reset marker for an instruction and all it's inputs recursively
240
241 if (inst->IsMarked(visited)) {
242 inst->ResetMarker(visited);
243
244 for (auto &input : inst->GetInputs()) {
245 auto inputInst = input.GetInst();
246 if (inputInst->IsMarked(visited)) {
247 ResetInputMarkersRecursively(inputInst, visited);
248 }
249 }
250 }
251 }
252
HasUser(Inst * inst,const FindUserPredicate & predicate)253 bool HasUser(Inst *inst, const FindUserPredicate &predicate)
254 {
255 // Check if instruction is used in a context defined by predicate
256
257 auto found = std::find_if(inst->GetUsers().begin(), inst->GetUsers().end(), predicate);
258 return found != inst->GetUsers().end();
259 }
260
HasUserPhiRecursively(Inst * inst,Marker visited,const FindUserPredicate & predicate)261 bool HasUserPhiRecursively(Inst *inst, Marker visited, const FindUserPredicate &predicate)
262 {
263 // Check if instruction is used in a context defined by predicate
264 // All Phi-instruction users are checked recursively
265
266 if (HasUser(inst, predicate)) {
267 return true;
268 }
269
270 inst->SetMarker(visited);
271
272 for (auto &user : inst->GetUsers()) {
273 auto userInst = user.GetInst();
274 if (!userInst->IsPhi()) {
275 continue;
276 }
277 if (userInst->IsMarked(visited)) {
278 continue;
279 }
280 if (HasUserPhiRecursively(userInst, visited, predicate)) {
281 return true;
282 }
283 }
284
285 return false;
286 }
287
HasUserRecursively(Inst * inst,Marker visited,const FindUserPredicate & predicate)288 bool HasUserRecursively(Inst *inst, Marker visited, const FindUserPredicate &predicate)
289 {
290 // Check if instruction is used in a context defined by predicate
291 // All Check-instruction users are checked recursively
292
293 if (HasUser(inst, predicate)) {
294 return true;
295 }
296
297 inst->SetMarker(visited);
298
299 for (auto &user : inst->GetUsers()) {
300 auto userInst = user.GetInst();
301 if (!userInst->IsCheck()) {
302 continue;
303 }
304 if (userInst->IsMarked(visited)) {
305 continue;
306 }
307 if (HasUserRecursively(userInst, visited, predicate)) {
308 return true;
309 }
310 }
311
312 return false;
313 }
314
CountUsers(Inst * inst,const FindUserPredicate & predicate)315 size_t CountUsers(Inst *inst, const FindUserPredicate &predicate)
316 {
317 size_t count = 0;
318 for (auto &user : inst->GetUsers()) {
319 if (predicate(user)) {
320 ++count;
321 }
322
323 auto userInst = user.GetInst();
324 if (userInst->IsCheck()) {
325 count += CountUsers(userInst, predicate);
326 }
327 }
328
329 return count;
330 }
331
ResetUserMarkersRecursively(Inst * inst,Marker visited)332 void ResetUserMarkersRecursively(Inst *inst, Marker visited)
333 {
334 // Reset marker for an instruction and all it's users recursively
335
336 if (inst->IsMarked(visited)) {
337 inst->ResetMarker(visited);
338
339 for (auto &user : inst->GetUsers()) {
340 auto userInst = user.GetInst();
341 if (userInst->IsMarked(visited)) {
342 ResetUserMarkersRecursively(userInst, visited);
343 }
344 }
345 }
346 }
347
SkipSingleUserCheckInstruction(Inst * inst)348 Inst *SkipSingleUserCheckInstruction(Inst *inst)
349 {
350 if (inst->IsCheck() && inst->HasSingleUser()) {
351 inst = inst->GetUsers().Front().GetInst();
352 }
353 return inst;
354 }
355
IsIntrinsicStringBuilderAppendString(Inst * inst)356 bool IsIntrinsicStringBuilderAppendString(Inst *inst)
357 {
358 if (!inst->IsIntrinsic()) {
359 return false;
360 }
361
362 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
363 return runtime->IsIntrinsicStringBuilderAppendString(inst->CastToIntrinsic()->GetIntrinsicId());
364 }
365
IsUsedOutsideBasicBlock(Inst * inst,BasicBlock * bb)366 bool IsUsedOutsideBasicBlock(Inst *inst, BasicBlock *bb)
367 {
368 for (auto &user : inst->GetUsers()) {
369 auto userInst = user.GetInst();
370 if (userInst->IsCheck()) {
371 if (!userInst->HasUsers()) {
372 continue;
373 }
374 if (!userInst->HasSingleUser()) {
375 // In case of multi user check-instruction we assume it is used outside current basic block without
376 // actually testing it.
377 return true;
378 }
379 // In case of single user check-instruction we test its the only user.
380 userInst = userInst->GetUsers().Front().GetInst();
381 }
382 if (userInst->GetBasicBlock() != bb) {
383 return true;
384 }
385 }
386 return false;
387 }
388
FindFirstSaveState(BasicBlock * block)389 SaveStateInst *FindFirstSaveState(BasicBlock *block)
390 {
391 if (block->IsEmpty()) {
392 return nullptr;
393 }
394
395 for (auto inst : block->Insts()) {
396 if (inst->GetOpcode() == Opcode::SaveState) {
397 return inst->CastToSaveState();
398 }
399 }
400
401 return nullptr;
402 }
403
RemoveFromInstructionInputs(ArenaVector<InputDesc> & inputDescriptors,bool doMarkSaveStates)404 void RemoveFromInstructionInputs(ArenaVector<InputDesc> &inputDescriptors, bool doMarkSaveStates)
405 {
406 // Inputs must be walked in reverse order for removal
407 std::sort(inputDescriptors.begin(), inputDescriptors.end(),
408 [](auto inputDescX, auto inputDescY) { return inputDescX.second > inputDescY.second; });
409
410 for (auto inputDesc : inputDescriptors) {
411 auto inst = inputDesc.first;
412 auto index = inputDesc.second;
413 inst->RemoveInput(index);
414 if (inst->IsSaveState() && doMarkSaveStates) {
415 auto *saveState = static_cast<SaveStateInst *>(inst);
416 saveState->SetInputsWereDeleted();
417 #ifndef NDEBUG
418 if (!saveState->CanRemoveInputs()) {
419 saveState->SetInputsWereDeletedSafely(); // assuming this is safe
420 }
421 #endif
422 }
423 }
424 }
425
BreakStringBuilderAppendChains(BasicBlock * block)426 bool BreakStringBuilderAppendChains(BasicBlock *block)
427 {
428 // StringBuilder append-call returns 'this' (instance)
429 // Replace all users of append-call with instance itself to support chain calls
430 // like: sb.append(s0).append(s1)...
431 bool isApplied = false;
432 for (auto inst : block->Insts()) {
433 if (!IsStringBuilderAppend(inst) && !IsStringBuilderToString(inst)) {
434 continue;
435 }
436
437 auto instance = inst->GetDataFlowInput(0);
438 for (auto &user : instance->GetUsers()) {
439 auto userInst = SkipSingleUserCheckInstruction(user.GetInst());
440 if (IsStringBuilderAppend(userInst)) {
441 userInst->ReplaceUsers(instance);
442 isApplied = true;
443 }
444 }
445 }
446 return isApplied;
447 }
448
GetStoreArrayIndexConstant(Inst * storeArray)449 Inst *GetStoreArrayIndexConstant(Inst *storeArray)
450 {
451 ASSERT(storeArray->GetOpcode() == Opcode::StoreArray);
452 ASSERT(storeArray->GetInputsCount() > 1);
453
454 auto inputInst1 = storeArray->GetDataFlowInput(1U);
455 if (inputInst1->IsConst()) {
456 return inputInst1;
457 }
458
459 return nullptr;
460 }
461
FillArrayElement(Inst * inst,InstVector & arrayElements)462 bool FillArrayElement(Inst *inst, InstVector &arrayElements)
463 {
464 if (inst->GetOpcode() == Opcode::StoreArray) {
465 auto indexInst = GetStoreArrayIndexConstant(inst);
466 if (indexInst == nullptr) {
467 return false;
468 }
469
470 ASSERT(indexInst->IsConst());
471 auto indexValue = indexInst->CastToConstant()->GetIntValue();
472 if (arrayElements[indexValue] != nullptr) {
473 return false;
474 }
475
476 auto element = inst->GetDataFlowInput(2U);
477 arrayElements[indexValue] = element;
478 }
479 return true;
480 }
481
FillArrayElements(Inst * inst,InstVector & arrayElements)482 bool FillArrayElements(Inst *inst, InstVector &arrayElements)
483 {
484 for (auto &user : inst->GetUsers()) {
485 auto userInst = user.GetInst();
486 if (!FillArrayElement(userInst, arrayElements)) {
487 return false;
488 }
489 if (userInst->GetOpcode() == Opcode::NullCheck) {
490 if (!FillArrayElements(userInst, arrayElements)) {
491 return false;
492 }
493 }
494 }
495 return true;
496 }
497
GetArrayLengthConstant(Inst * newArray)498 Inst *GetArrayLengthConstant(Inst *newArray)
499 {
500 if (newArray->GetOpcode() != Opcode::NewArray) {
501 return nullptr;
502 }
503 ASSERT(newArray->GetInputsCount() > 1);
504
505 auto inputInst1 = newArray->GetDataFlowInput(1U);
506 if (inputInst1->IsConst()) {
507 return inputInst1;
508 }
509
510 return nullptr;
511 }
512
CollectArrayElements(Inst * newArray,InstVector & arrayElements)513 bool CollectArrayElements(Inst *newArray, InstVector &arrayElements)
514 {
515 /*
516 Collect instructions stored to a given array
517
518 This functions used to find all the arguments of the calls like:
519 str.concat(a, b, c)
520 IR builder generates the following IR for it:
521
522 bb_start:
523 v0 Constant 0x0
524 v1 Constant 0x1
525 v2 Constant 0x2
526 v3 Constant 0x3
527 bb1:
528 v9 NewArray class, v3, save_state
529 v10 StoreArray v9, v0, a
530 v11 StoreArray v9, v1, b
531 v12 StoreArray v9, v2, c
532 v20 CallStatic String::concat str, v9, save_state
533
534 Conditions:
535 - array size is constant (3 in the sample code above)
536 - every StoreArray instruction stores value by constant index (0, 1 and 2 in the sample code above)
537 - every element stored only once
538 - array filled completely
539
540 If any of the above is false, this functions returns false and clears array.
541 If all the above conditions true, this function returns true and fills array.
542 */
543
544 ASSERT(newArray->GetOpcode() == Opcode::NewArray);
545 arrayElements.clear();
546
547 auto lengthInst = GetArrayLengthConstant(newArray);
548 if (lengthInst == nullptr) {
549 return false;
550 }
551 ASSERT(lengthInst->IsConst());
552
553 auto length = lengthInst->CastToConstant()->GetIntValue();
554 arrayElements.resize(length);
555
556 if (!FillArrayElements(newArray, arrayElements)) {
557 arrayElements.clear();
558 return false;
559 }
560
561 // Check if array is filled completely
562 auto foundNull =
563 std::find_if(arrayElements.begin(), arrayElements.end(), [](auto &element) { return element == nullptr; });
564 if (foundNull != arrayElements.end()) {
565 arrayElements.clear();
566 return false;
567 }
568
569 return true;
570 }
571
CleanupStoreArrayInstructions(Inst * inst)572 void CleanupStoreArrayInstructions(Inst *inst)
573 {
574 for (auto &user : inst->GetUsers()) {
575 auto userInst = user.GetInst();
576 if (userInst->GetOpcode() == Opcode::StoreArray) {
577 userInst->ClearFlag(inst_flags::NO_DCE);
578 }
579 if (userInst->IsCheck()) {
580 userInst->ClearFlag(inst_flags::NO_DCE);
581 CleanupStoreArrayInstructions(userInst);
582 }
583 }
584 }
585
586 } // namespace ark::compiler
587