1 /**
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimize_string_concat.h"
17
18 #include "compiler_logger.h"
19
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23 #include "optimizer/ir/analysis.h"
24 #include "optimizer/ir/datatype.h"
25 #include "optimizer/ir/graph.h"
26 #include "optimizer/ir/inst.h"
27
28 #include "optimizer/ir/runtime_interface.h"
29 #include "optimizer/optimizations/cleanup.h"
30 #include "optimizer/optimizations/string_builder_utils.h"
31
32 namespace ark::compiler {
33
OptimizeStringConcat(Graph * graph)34 OptimizeStringConcat::OptimizeStringConcat(Graph *graph)
35 : Optimization(graph), arrayElements_ {graph->GetAllocator()->Adapter()}
36 {
37 }
38
GetStringBuilderClassId(Graph * graph)39 RuntimeInterface::IdType GetStringBuilderClassId(Graph *graph)
40 {
41 auto runtime = graph->GetRuntime();
42 auto klass = runtime->GetStringBuilderClass();
43 return klass == nullptr ? 0 : runtime->GetClassIdWithinFile(graph->GetMethod(), klass);
44 }
45
RunImpl()46 bool OptimizeStringConcat::RunImpl()
47 {
48 bool isApplied = false;
49
50 if (GetStringBuilderClassId(GetGraph()) == 0) {
51 COMPILER_LOG(WARNING, OPTIMIZE_STRING_CONCAT) << "StringBuilder class not found";
52 return isApplied;
53 }
54
55 if (GetGraph()->IsAotMode()) {
56 // NOTE(mivanov): Creating StringBuilder.ctor calls in AOT mode not yet supported
57 return isApplied;
58 }
59
60 for (auto block : GetGraph()->GetBlocksRPO()) {
61 for (auto inst : block->Insts()) {
62 if (!IsMethodStringConcat(inst)) {
63 continue;
64 }
65
66 ReplaceStringConcatWithStringBuilderAppend(inst);
67 isApplied = true;
68 }
69 }
70
71 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Optimize String.concat complete";
72
73 if (isApplied) {
74 GetGraph()->RunPass<compiler::Cleanup>();
75 }
76
77 return isApplied;
78 }
79
InvalidateAnalyses()80 void OptimizeStringConcat::InvalidateAnalyses()
81 {
82 GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
83 GetGraph()->InvalidateAnalysis<AliasAnalysis>();
84 }
85
CreateInstructionStringBuilderInstance(Graph * graph,uint32_t pc,SaveStateInst * saveState)86 Inst *CreateInstructionStringBuilderInstance(Graph *graph, uint32_t pc, SaveStateInst *saveState)
87 {
88 auto runtime = graph->GetRuntime();
89 auto method = graph->GetMethod();
90
91 auto classId = GetStringBuilderClassId(graph);
92 auto loadClass =
93 graph->CreateInstLoadAndInitClass(DataType::REFERENCE, pc, CopySaveState(graph, saveState),
94 TypeIdMixin {classId, method}, runtime->ResolveType(method, classId));
95 auto newObject = graph->CreateInstNewObject(DataType::REFERENCE, pc, loadClass, CopySaveState(graph, saveState),
96 TypeIdMixin {classId, method});
97
98 return newObject;
99 }
100
CreateStringBuilderAppendStringIntrinsic(Graph * graph,Inst * instance,Inst * arg,SaveStateInst * saveState)101 IntrinsicInst *CreateStringBuilderAppendStringIntrinsic(Graph *graph, Inst *instance, Inst *arg,
102 SaveStateInst *saveState)
103 {
104 auto appendIntrinsic = graph->CreateInstIntrinsic(graph->GetRuntime()->GetStringBuilderAppendStringsIntrinsicId(1));
105 ASSERT(appendIntrinsic->RequireState());
106
107 appendIntrinsic->SetType(DataType::REFERENCE);
108 auto saveStateClone = CopySaveState(graph, saveState);
109 appendIntrinsic->SetInputs(
110 graph->GetAllocator(),
111 {{instance, instance->GetType()}, {arg, arg->GetType()}, {saveStateClone, saveStateClone->GetType()}});
112
113 return appendIntrinsic;
114 }
115
CreateStringBuilderToStringIntrinsic(Graph * graph,Inst * instance,SaveStateInst * saveState)116 IntrinsicInst *CreateStringBuilderToStringIntrinsic(Graph *graph, Inst *instance, SaveStateInst *saveState)
117 {
118 auto toStringCall = graph->CreateInstIntrinsic(graph->GetRuntime()->GetStringBuilderToStringIntrinsicId());
119 ASSERT(toStringCall->RequireState());
120
121 toStringCall->SetType(DataType::REFERENCE);
122 auto saveStateClone = CopySaveState(graph, saveState);
123 toStringCall->SetInputs(graph->GetAllocator(),
124 {{instance, instance->GetType()}, {saveStateClone, saveStateClone->GetType()}});
125
126 return toStringCall;
127 }
128
CreateStringBuilderDefaultConstructorCall(Graph * graph,Inst * instance,SaveStateInst * saveState)129 CallInst *CreateStringBuilderDefaultConstructorCall(Graph *graph, Inst *instance, SaveStateInst *saveState)
130 {
131 auto runtime = graph->GetRuntime();
132 auto method = runtime->GetStringBuilderDefaultConstructor();
133 auto methodId = runtime->GetMethodId(method);
134
135 auto ctorCall = graph->CreateInstCallStatic(DataType::VOID, instance->GetPc(), methodId, method);
136 ASSERT(ctorCall->RequireState());
137
138 auto saveStateClone = CopySaveState(graph, saveState);
139 ctorCall->SetInputs(graph->GetAllocator(),
140 {{instance, instance->GetType()}, {saveStateClone, saveStateClone->GetType()}});
141
142 return ctorCall;
143 }
144
CreateLoadArray(Graph * graph,Inst * array,Inst * index)145 Inst *CreateLoadArray(Graph *graph, Inst *array, Inst *index)
146 {
147 return graph->CreateInstLoadArray(DataType::REFERENCE, array->GetPc(), array, index);
148 }
149
CreateLoadArray(Graph * graph,Inst * array,uint64_t index)150 Inst *CreateLoadArray(Graph *graph, Inst *array, uint64_t index)
151 {
152 return CreateLoadArray(graph, array, graph->FindOrCreateConstant(index));
153 }
154
CreateLenArray(Graph * graph,Inst * newArray)155 Inst *CreateLenArray(Graph *graph, Inst *newArray)
156 {
157 return graph->CreateInstLenArray(DataType::INT32, newArray->GetPc(), newArray);
158 }
159
FixBrokenSaveStates(Inst * source,Inst * target)160 void OptimizeStringConcat::FixBrokenSaveStates(Inst *source, Inst *target)
161 {
162 if (source->IsMovableObject()) {
163 ssb_.SearchAndCreateMissingObjInSaveState(GetGraph(), source, target);
164 }
165 }
166
CreateAppendArgsIntrinsic(Inst * instance,Inst * arg,SaveStateInst * saveState)167 void OptimizeStringConcat::CreateAppendArgsIntrinsic(Inst *instance, Inst *arg, SaveStateInst *saveState)
168 {
169 auto appendIntrinsic = CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, arg, saveState);
170 InsertBeforeWithInputs(appendIntrinsic, saveState);
171
172 FixBrokenSaveStates(arg, appendIntrinsic);
173 FixBrokenSaveStates(instance, appendIntrinsic);
174
175 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
176 << "Insert StringBuilder.append intrinsic (id=" << appendIntrinsic->GetId() << ")";
177 }
178
CreateAppendArgsIntrinsics(Inst * instance,SaveStateInst * saveState)179 void OptimizeStringConcat::CreateAppendArgsIntrinsics(Inst *instance, SaveStateInst *saveState)
180 {
181 for (auto arg : arrayElements_) {
182 CreateAppendArgsIntrinsic(instance, arg, saveState);
183 }
184 }
185
CreateAppendArgsIntrinsics(Inst * instance,Inst * args,uint64_t arrayLengthValue,SaveStateInst * saveState)186 void OptimizeStringConcat::CreateAppendArgsIntrinsics(Inst *instance, Inst *args, uint64_t arrayLengthValue,
187 SaveStateInst *saveState)
188 {
189 for (uint64_t index = 0; index < arrayLengthValue; ++index) {
190 auto arg = CreateLoadArray(GetGraph(), args, index);
191 CreateAppendArgsIntrinsic(instance, arg, saveState);
192 }
193 }
194
CreateSafePoint(Graph * graph,uint32_t pc,SaveStateInst * saveState)195 Inst *CreateSafePoint(Graph *graph, uint32_t pc, SaveStateInst *saveState)
196 {
197 auto safePoint =
198 graph->CreateInstSafePoint(pc, graph->GetMethod(), saveState->GetCallerInst(), saveState->GetInliningDepth());
199
200 for (size_t index = 0; index < saveState->GetInputsCount(); ++index) {
201 safePoint->AppendInput(saveState->GetInput(index));
202 safePoint->SetVirtualRegister(index, saveState->GetVirtualRegister(index));
203 }
204
205 return safePoint;
206 }
207
CreateAppendArgsLoop(Inst * instance,Inst * str,Inst * args,LengthMethodInst * arrayLength,Inst * concatCall)208 BasicBlock *OptimizeStringConcat::CreateAppendArgsLoop(Inst *instance, Inst *str, Inst *args,
209 LengthMethodInst *arrayLength, Inst *concatCall)
210 {
211 auto preHeader = concatCall->GetBasicBlock();
212 auto postExit = preHeader->SplitBlockAfterInstruction(concatCall, false);
213 auto saveState = concatCall->GetSaveState();
214
215 // Create loop CFG
216 auto header = GetGraph()->CreateEmptyBlock(preHeader);
217 auto backEdge = GetGraph()->CreateEmptyBlock(preHeader);
218 preHeader->AddSucc(header);
219 header->AddSucc(postExit);
220 header->AddSucc(backEdge);
221 backEdge->AddSucc(header);
222
223 // Declare loop variables
224 auto start = GetGraph()->FindOrCreateConstant(0);
225 auto stop = arrayLength;
226 auto step = GetGraph()->FindOrCreateConstant(1);
227
228 auto pc = instance->GetPc();
229
230 // Build header
231 auto induction = GetGraph()->CreateInstPhi(DataType::INT32, pc);
232 auto safePoint = CreateSafePoint(GetGraph(), pc, saveState);
233 auto compare =
234 GetGraph()->CreateInstCompare(DataType::BOOL, pc, stop, induction, DataType::INT32, ConditionCode::CC_LE);
235 auto ifImm = GetGraph()->CreateInstIfImm(DataType::BOOL, pc, compare, 0, DataType::BOOL, ConditionCode::CC_NE);
236 header->AppendPhi(induction);
237 header->AppendInsts({
238 safePoint,
239 compare,
240 ifImm,
241 });
242
243 // Build back edge
244 auto arg = CreateLoadArray(GetGraph(), args, induction);
245 auto appendIntrinsic = CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, arg, saveState);
246 auto add = GetGraph()->CreateInstAdd(DataType::INT32, pc, induction, step);
247 backEdge->AppendInsts({
248 arg,
249 appendIntrinsic->GetSaveState(),
250 appendIntrinsic,
251 add,
252 });
253
254 // Connect loop induction variable inputs
255 induction->AppendInput(start);
256 induction->AppendInput(add);
257
258 FixBrokenSaveStates(str, appendIntrinsic);
259 FixBrokenSaveStates(args, appendIntrinsic);
260 FixBrokenSaveStates(arg, appendIntrinsic);
261 FixBrokenSaveStates(instance, appendIntrinsic);
262
263 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
264 << "Insert StringBuilder.append intrinsic (id=" << appendIntrinsic->GetId() << ")";
265
266 return postExit;
267 }
268
HasStoreArrayUsersOnly(Inst * newArray,Inst * removable)269 bool OptimizeStringConcat::HasStoreArrayUsersOnly(Inst *newArray, Inst *removable)
270 {
271 ASSERT(newArray->GetOpcode() == Opcode::NewArray);
272
273 MarkerHolder visited {newArray->GetBasicBlock()->GetGraph()};
274 bool found = HasUserRecursively(newArray, visited.GetMarker(), [newArray, removable](auto &user) {
275 auto userInst = user.GetInst();
276 auto isSaveState = userInst->IsSaveState();
277 auto isCheck = userInst->IsCheck();
278 auto isStoreArray = userInst->GetOpcode() == Opcode::StoreArray && userInst->GetDataFlowInput(0) == newArray;
279 bool isRemovable = userInst == removable;
280 return !isSaveState && !isCheck && !isStoreArray && !isRemovable;
281 });
282
283 ResetUserMarkersRecursively(newArray, visited.GetMarker());
284 return !found;
285 }
286
ReplaceStringConcatWithStringBuilderAppend(Inst * concatCall)287 void OptimizeStringConcat::ReplaceStringConcatWithStringBuilderAppend(Inst *concatCall)
288 {
289 // Input:
290 // let result = str.concat(...args)
291 //
292 // Output:
293 // let instance = new StringBuilder(str)
294 // instance.append(args[0])
295 // ...
296 // instance.append(args[args.length-1])
297 // let result = instance.toString()
298
299 ASSERT(concatCall->GetInputsCount() > 1);
300
301 auto str = concatCall->GetDataFlowInput(0);
302 auto args = concatCall->GetDataFlowInput(1);
303
304 auto instance = CreateInstructionStringBuilderInstance(GetGraph(), concatCall->GetPc(), concatCall->GetSaveState());
305 InsertBeforeWithInputs(instance, concatCall->GetSaveState());
306
307 auto ctorCall = CreateStringBuilderDefaultConstructorCall(GetGraph(), instance, concatCall->GetSaveState());
308 InsertBeforeWithSaveState(ctorCall, concatCall->GetSaveState());
309 auto appendArgIntrinsic =
310 CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, str, concatCall->GetSaveState());
311 InsertBeforeWithSaveState(appendArgIntrinsic, concatCall->GetSaveState());
312 FixBrokenSaveStates(instance, appendArgIntrinsic);
313 FixBrokenSaveStates(str, appendArgIntrinsic);
314
315 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
316 << "Insert StringBuilder.append intrinsic (id=" << appendArgIntrinsic->GetId() << ")";
317
318 auto toStringCall = CreateStringBuilderToStringIntrinsic(GetGraph(), instance, concatCall->GetSaveState());
319
320 auto arrayLength = GetArrayLengthConstant(args);
321 bool collected = args->GetOpcode() == Opcode::NewArray && CollectArrayElements(args, arrayElements_);
322 if (collected) {
323 CreateAppendArgsIntrinsics(instance, concatCall->GetSaveState());
324 InsertBeforeWithSaveState(toStringCall, concatCall->GetSaveState());
325 } else if (args->GetOpcode() == Opcode::NewArray && arrayLength != nullptr) {
326 CreateAppendArgsIntrinsics(instance, args, arrayLength->CastToConstant()->GetIntValue(),
327 concatCall->GetSaveState());
328 InsertBeforeWithSaveState(toStringCall, concatCall->GetSaveState());
329 } else {
330 arrayLength = CreateLenArray(GetGraph(), args);
331 concatCall->GetSaveState()->InsertBefore(arrayLength);
332 auto postExit = CreateAppendArgsLoop(instance, str, args, arrayLength->CastToLenArray(), concatCall);
333
334 postExit->PrependInst(toStringCall);
335 postExit->PrependInst(toStringCall->GetSaveState());
336
337 InvalidateBlocksOrderAnalyzes(GetGraph());
338 GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
339 }
340 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Replace String.concat call (id=" << concatCall->GetId()
341 << ") with StringBuilder instance (id=" << instance->GetId() << ")";
342
343 FixBrokenSaveStates(instance, toStringCall);
344
345 concatCall->ReplaceUsers(toStringCall);
346
347 concatCall->ClearFlag(inst_flags::NO_DCE);
348 if (concatCall->GetInput(0).GetInst()->IsCheck()) {
349 concatCall->GetInput(0).GetInst()->ClearFlag(inst_flags::NO_DCE);
350 }
351
352 if (collected && HasStoreArrayUsersOnly(args, concatCall)) {
353 CleanupStoreArrayInstructions(args);
354 args->ClearFlag(inst_flags::NO_DCE);
355 }
356 }
357
358 } // namespace ark::compiler
359