• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "string_builder_utils.h"
17 
18 namespace ark::compiler {
19 
IsStringBuilderInstance(Inst * inst)20 bool IsStringBuilderInstance(Inst *inst)
21 {
22     if (inst->GetOpcode() != Opcode::NewObject) {
23         return false;
24     }
25 
26     auto klass = GetObjectClass(inst->CastToNewObject());
27     if (klass == nullptr) {
28         return false;
29     }
30 
31     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
32     return runtime->IsClassStringBuilder(klass);
33 }
34 
IsMethodStringConcat(Inst * inst)35 bool IsMethodStringConcat(Inst *inst)
36 {
37     if (inst->GetOpcode() != Opcode::CallStatic && inst->GetOpcode() != Opcode::CallVirtual) {
38         return false;
39     }
40 
41     auto call = static_cast<CallInst *>(inst);
42     if (call->IsInlined()) {
43         return false;
44     }
45 
46     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
47     return runtime->IsMethodStringConcat(call->GetCallMethod());
48 }
49 
IsMethodStringBuilderConstructorWithStringArg(Inst * inst)50 bool IsMethodStringBuilderConstructorWithStringArg(Inst *inst)
51 {
52     if (inst->GetOpcode() != Opcode::CallStatic) {
53         return false;
54     }
55 
56     auto call = inst->CastToCallStatic();
57     if (call->IsInlined()) {
58         return false;
59     }
60 
61     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
62     return runtime->IsMethodStringBuilderConstructorWithStringArg(call->GetCallMethod());
63 }
64 
IsMethodStringBuilderConstructorWithCharArrayArg(Inst * inst)65 bool IsMethodStringBuilderConstructorWithCharArrayArg(Inst *inst)
66 {
67     if (inst->GetOpcode() != Opcode::CallStatic) {
68         return false;
69     }
70 
71     auto call = inst->CastToCallStatic();
72     if (call->IsInlined()) {
73         return false;
74     }
75 
76     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
77     return runtime->IsMethodStringBuilderConstructorWithCharArrayArg(call->GetCallMethod());
78 }
79 
IsStringBuilderToString(Inst * inst)80 bool IsStringBuilderToString(Inst *inst)
81 {
82     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
83     if (inst->GetOpcode() == Opcode::CallStatic || inst->GetOpcode() == Opcode::CallVirtual) {
84         auto callInst = static_cast<CallInst *>(inst);
85         return !callInst->IsInlined() && runtime->IsMethodStringBuilderToString(callInst->GetCallMethod());
86     }
87     if (inst->IsIntrinsic()) {
88         auto intrinsic = inst->CastToIntrinsic();
89         return runtime->IsIntrinsicStringBuilderToString(intrinsic->GetIntrinsicId());
90     }
91     return false;
92 }
93 
IsMethodStringBuilderDefaultConstructor(Inst * inst)94 bool IsMethodStringBuilderDefaultConstructor(Inst *inst)
95 {
96     if (inst->GetOpcode() != Opcode::CallStatic) {
97         return false;
98     }
99 
100     auto call = inst->CastToCallStatic();
101     if (call->IsInlined()) {
102         return false;
103     }
104 
105     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
106     return runtime->IsMethodStringBuilderDefaultConstructor(call->GetCallMethod());
107 }
108 
InsertBeforeWithSaveState(Inst * inst,Inst * before)109 void InsertBeforeWithSaveState(Inst *inst, Inst *before)
110 {
111     if (inst->RequireState()) {
112         before->InsertBefore(inst->GetSaveState());
113     }
114     before->InsertBefore(inst);
115 }
116 
InsertAfterWithSaveState(Inst * inst,Inst * after)117 void InsertAfterWithSaveState(Inst *inst, Inst *after)
118 {
119     after->InsertAfter(inst);
120     if (inst->RequireState()) {
121         after->InsertAfter(inst->GetSaveState());
122     }
123 }
124 
InsertBeforeWithInputs(Inst * inst,Inst * before)125 void InsertBeforeWithInputs(Inst *inst, Inst *before)
126 {
127     for (auto &input : inst->GetInputs()) {
128         auto inputInst = input.GetInst();
129         if (inputInst->GetBasicBlock() == nullptr) {
130             InsertBeforeWithInputs(inputInst, before);
131         }
132     }
133 
134     if (inst->GetBasicBlock() == nullptr) {
135         before->InsertBefore(inst);
136     }
137 }
138 
HasInput(Inst * inst,const FindInputPredicate & predicate)139 bool HasInput(Inst *inst, const FindInputPredicate &predicate)
140 {
141     // Check if any instruction input satisfy predicate
142 
143     auto found = std::find_if(inst->GetInputs().begin(), inst->GetInputs().end(), predicate);
144     return found != inst->GetInputs().end();
145 }
146 
HasInputPhiRecursively(Inst * inst,Marker visited,const FindInputPredicate & predicate)147 bool HasInputPhiRecursively(Inst *inst, Marker visited, const FindInputPredicate &predicate)
148 {
149     // Check if any instruction input satisfy predicate
150     // All Phi-instruction inputs are checked recursively
151 
152     if (HasInput(inst, predicate)) {
153         return true;
154     }
155 
156     inst->SetMarker(visited);
157 
158     for (auto &input : inst->GetInputs()) {
159         auto inputInst = input.GetInst();
160         if (!inputInst->IsPhi()) {
161             continue;
162         }
163         if (inputInst->IsMarked(visited)) {
164             continue;
165         }
166         if (HasInputPhiRecursively(inputInst, visited, predicate)) {
167             return true;
168         }
169     }
170 
171     return false;
172 }
173 
ResetInputMarkersRecursively(Inst * inst,Marker visited)174 void ResetInputMarkersRecursively(Inst *inst, Marker visited)
175 {
176     // Reset marker for an instruction and all it's inputs recursively
177 
178     if (inst->IsMarked(visited)) {
179         inst->ResetMarker(visited);
180 
181         for (auto &input : inst->GetInputs()) {
182             auto inputInst = input.GetInst();
183             if (inputInst->IsMarked(visited)) {
184                 ResetInputMarkersRecursively(inputInst, visited);
185             }
186         }
187     }
188 }
189 
HasUser(Inst * inst,const FindUserPredicate & predicate)190 bool HasUser(Inst *inst, const FindUserPredicate &predicate)
191 {
192     // Check if instruction is used in a context defined by predicate
193 
194     auto found = std::find_if(inst->GetUsers().begin(), inst->GetUsers().end(), predicate);
195     return found != inst->GetUsers().end();
196 }
197 
HasUserPhiRecursively(Inst * inst,Marker visited,const FindUserPredicate & predicate)198 bool HasUserPhiRecursively(Inst *inst, Marker visited, const FindUserPredicate &predicate)
199 {
200     // Check if instruction is used in a context defined by predicate
201     // All Phi-instruction users are checked recursively
202 
203     if (HasUser(inst, predicate)) {
204         return true;
205     }
206 
207     inst->SetMarker(visited);
208 
209     for (auto &user : inst->GetUsers()) {
210         auto userInst = user.GetInst();
211         if (!userInst->IsPhi()) {
212             continue;
213         }
214         if (userInst->IsMarked(visited)) {
215             continue;
216         }
217         if (HasUserPhiRecursively(userInst, visited, predicate)) {
218             return true;
219         }
220     }
221 
222     return false;
223 }
224 
HasUserRecursively(Inst * inst,Marker visited,const FindUserPredicate & predicate)225 bool HasUserRecursively(Inst *inst, Marker visited, const FindUserPredicate &predicate)
226 {
227     // Check if instruction is used in a context defined by predicate
228     // All Check-instruction users are checked recursively
229 
230     if (HasUser(inst, predicate)) {
231         return true;
232     }
233 
234     inst->SetMarker(visited);
235 
236     for (auto &user : inst->GetUsers()) {
237         auto userInst = user.GetInst();
238         if (!userInst->IsCheck()) {
239             continue;
240         }
241         if (userInst->IsMarked(visited)) {
242             continue;
243         }
244         if (HasUserRecursively(userInst, visited, predicate)) {
245             return true;
246         }
247     }
248 
249     return false;
250 }
251 
CountUsers(Inst * inst,const FindUserPredicate & predicate)252 size_t CountUsers(Inst *inst, const FindUserPredicate &predicate)
253 {
254     size_t count = 0;
255     for (auto &user : inst->GetUsers()) {
256         if (predicate(user)) {
257             ++count;
258         }
259 
260         auto userInst = user.GetInst();
261         if (userInst->IsCheck()) {
262             count += CountUsers(userInst, predicate);
263         }
264     }
265 
266     return count;
267 }
268 
ResetUserMarkersRecursively(Inst * inst,Marker visited)269 void ResetUserMarkersRecursively(Inst *inst, Marker visited)
270 {
271     // Reset marker for an instruction and all it's users recursively
272 
273     if (inst->IsMarked(visited)) {
274         inst->ResetMarker(visited);
275 
276         for (auto &user : inst->GetUsers()) {
277             auto userInst = user.GetInst();
278             if (userInst->IsMarked(visited)) {
279                 ResetUserMarkersRecursively(userInst, visited);
280             }
281         }
282     }
283 }
284 
SkipSingleUserCheckInstruction(Inst * inst)285 Inst *SkipSingleUserCheckInstruction(Inst *inst)
286 {
287     if (inst->IsCheck() && inst->HasSingleUser()) {
288         inst = inst->GetUsers().Front().GetInst();
289     }
290     return inst;
291 }
292 
IsIntrinsicStringBuilderAppendString(Inst * inst)293 bool IsIntrinsicStringBuilderAppendString(Inst *inst)
294 {
295     if (!inst->IsIntrinsic()) {
296         return false;
297     }
298 
299     auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
300     return runtime->IsIntrinsicStringBuilderAppendString(inst->CastToIntrinsic()->GetIntrinsicId());
301 }
302 
IsUsedOutsideBasicBlock(Inst * inst,BasicBlock * bb)303 bool IsUsedOutsideBasicBlock(Inst *inst, BasicBlock *bb)
304 {
305     for (auto &user : inst->GetUsers()) {
306         auto userInst = user.GetInst();
307         if (userInst->IsCheck()) {
308             if (!userInst->HasUsers()) {
309                 continue;
310             }
311             if (!userInst->HasSingleUser()) {
312                 // In case of multi user check-instruction we assume it is used outside current basic block without
313                 // actually testing it.
314                 return true;
315             }
316             // In case of single user check-instruction we test its the only user.
317             userInst = userInst->GetUsers().Front().GetInst();
318         }
319         if (userInst->GetBasicBlock() != bb) {
320             return true;
321         }
322     }
323     return false;
324 }
325 
FindFirstSaveState(BasicBlock * block)326 SaveStateInst *FindFirstSaveState(BasicBlock *block)
327 {
328     if (block->IsEmpty()) {
329         return nullptr;
330     }
331 
332     for (auto inst : block->Insts()) {
333         if (inst->GetOpcode() == Opcode::SaveState) {
334             return inst->CastToSaveState();
335         }
336     }
337 
338     return nullptr;
339 }
340 
RemoveFromInstructionInputs(ArenaVector<InputDesc> & inputDescriptors)341 void RemoveFromInstructionInputs(ArenaVector<InputDesc> &inputDescriptors)
342 {
343     // Inputs must be walked in reverse order for removal
344     std::sort(inputDescriptors.begin(), inputDescriptors.end(),
345               [](auto inputDescX, auto inputDescY) { return inputDescX.second > inputDescY.second; });
346 
347     for (auto inputDesc : inputDescriptors) {
348         auto inst = inputDesc.first;
349         auto index = inputDesc.second;
350         inst->RemoveInput(index);
351     }
352 }
353 
BreakStringBuilderAppendChains(BasicBlock * block)354 bool BreakStringBuilderAppendChains(BasicBlock *block)
355 {
356     // StringBuilder append-call returns 'this' (instance)
357     // Replace all users of append-call with instance itself to support chain calls
358     // like: sb.append(s0).append(s1)...
359     bool isApplied = false;
360     for (auto inst : block->Insts()) {
361         if (!IsStringBuilderAppend(inst) && !IsStringBuilderToString(inst)) {
362             continue;
363         }
364 
365         auto instance = inst->GetDataFlowInput(0);
366         for (auto &user : instance->GetUsers()) {
367             auto userInst = SkipSingleUserCheckInstruction(user.GetInst());
368             if (IsStringBuilderAppend(userInst)) {
369                 userInst->ReplaceUsers(instance);
370                 isApplied = true;
371             }
372         }
373     }
374     return isApplied;
375 }
376 
GetStoreArrayIndexConstant(Inst * storeArray)377 Inst *GetStoreArrayIndexConstant(Inst *storeArray)
378 {
379     ASSERT(storeArray->GetOpcode() == Opcode::StoreArray);
380     ASSERT(storeArray->GetInputsCount() > 1);
381 
382     auto inputInst1 = storeArray->GetDataFlowInput(1U);
383     if (inputInst1->IsConst()) {
384         return inputInst1;
385     }
386 
387     return nullptr;
388 }
389 
FillArrayElement(Inst * inst,InstVector & arrayElements)390 bool FillArrayElement(Inst *inst, InstVector &arrayElements)
391 {
392     if (inst->GetOpcode() == Opcode::StoreArray) {
393         auto indexInst = GetStoreArrayIndexConstant(inst);
394         if (indexInst == nullptr) {
395             return false;
396         }
397 
398         ASSERT(indexInst->IsConst());
399         auto indexValue = indexInst->CastToConstant()->GetIntValue();
400         if (arrayElements[indexValue] != nullptr) {
401             return false;
402         }
403 
404         auto element = inst->GetDataFlowInput(2U);
405         arrayElements[indexValue] = element;
406     }
407     return true;
408 }
409 
FillArrayElements(Inst * inst,InstVector & arrayElements)410 bool FillArrayElements(Inst *inst, InstVector &arrayElements)
411 {
412     for (auto &user : inst->GetUsers()) {
413         auto userInst = user.GetInst();
414         if (!FillArrayElement(userInst, arrayElements)) {
415             return false;
416         }
417         if (userInst->GetOpcode() == Opcode::NullCheck) {
418             if (!FillArrayElements(userInst, arrayElements)) {
419                 return false;
420             }
421         }
422     }
423     return true;
424 }
425 
GetArrayLengthConstant(Inst * newArray)426 Inst *GetArrayLengthConstant(Inst *newArray)
427 {
428     if (newArray->GetOpcode() != Opcode::NewArray) {
429         return nullptr;
430     }
431     ASSERT(newArray->GetInputsCount() > 1);
432 
433     auto inputInst1 = newArray->GetDataFlowInput(1U);
434     if (inputInst1->IsConst()) {
435         return inputInst1;
436     }
437 
438     return nullptr;
439 }
440 
CollectArrayElements(Inst * newArray,InstVector & arrayElements)441 bool CollectArrayElements(Inst *newArray, InstVector &arrayElements)
442 {
443     /*
444         Collect instructions stored to a given array
445 
446         This functions used to find all the arguments of the calls like:
447             str.concat(a, b, c)
448         IR builder generates the following IR for it:
449 
450         bb_start:
451             v0  Constant 0x0
452             v1  Constant 0x1
453             v2  Constant 0x2
454             v3  Constant 0x3
455         bb1:
456             v9  NewArray class, v3, save_state
457             v10 StoreArray v9, v0, a
458             v11 StoreArray v9, v1, b
459             v12 StoreArray v9, v2, c
460             v20 CallStatic String::concat str, v9, save_state
461 
462         Conditions:
463             - array size is constant (3 in the sample code above)
464             - every StoreArray instruction stores value by constant index (0, 1 and 2 in the sample code above)
465             - every element stored only once
466             - array filled completely
467 
468         If any of the above is false, this functions returns false and clears array.
469         If all the above conditions true, this function returns true and fills array.
470     */
471 
472     ASSERT(newArray->GetOpcode() == Opcode::NewArray);
473     arrayElements.clear();
474 
475     auto lengthInst = GetArrayLengthConstant(newArray);
476     if (lengthInst == nullptr) {
477         return false;
478     }
479     ASSERT(lengthInst->IsConst());
480 
481     auto length = lengthInst->CastToConstant()->GetIntValue();
482     arrayElements.resize(length);
483 
484     if (!FillArrayElements(newArray, arrayElements)) {
485         arrayElements.clear();
486         return false;
487     }
488 
489     // Check if array is filled completely
490     auto foundNull =
491         std::find_if(arrayElements.begin(), arrayElements.end(), [](auto &element) { return element == nullptr; });
492     if (foundNull != arrayElements.end()) {
493         arrayElements.clear();
494         return false;
495     }
496 
497     return true;
498 }
499 
CleanupStoreArrayInstructions(Inst * inst)500 void CleanupStoreArrayInstructions(Inst *inst)
501 {
502     for (auto &user : inst->GetUsers()) {
503         auto userInst = user.GetInst();
504         if (userInst->GetOpcode() == Opcode::StoreArray) {
505             userInst->ClearFlag(inst_flags::NO_DCE);
506         }
507         if (userInst->IsCheck()) {
508             userInst->ClearFlag(inst_flags::NO_DCE);
509             CleanupStoreArrayInstructions(userInst);
510         }
511     }
512 }
513 
514 }  // namespace ark::compiler
515