1 /**
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimize_string_concat.h"
17
18 #include "compiler_logger.h"
19
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23 #include "optimizer/ir/analysis.h"
24 #include "optimizer/ir/datatype.h"
25 #include "optimizer/ir/graph.h"
26 #include "optimizer/ir/inst.h"
27
28 #include "optimizer/ir/runtime_interface.h"
29 #include "optimizer/optimizations/cleanup.h"
30 #include "optimizer/optimizations/string_builder_utils.h"
31
32 namespace ark::compiler {
33
OptimizeStringConcat(Graph * graph)34 OptimizeStringConcat::OptimizeStringConcat(Graph *graph) : Optimization(graph) {}
35
GetStringBuilderClassId(Graph * graph)36 RuntimeInterface::IdType GetStringBuilderClassId(Graph *graph)
37 {
38 auto runtime = graph->GetRuntime();
39 auto klass = runtime->GetStringBuilderClass();
40 return klass == nullptr ? 0 : runtime->GetClassIdWithinFile(graph->GetMethod(), klass);
41 }
42
RunImpl()43 bool OptimizeStringConcat::RunImpl()
44 {
45 bool isApplied = false;
46
47 if (GetStringBuilderClassId(GetGraph()) == 0) {
48 COMPILER_LOG(WARNING, OPTIMIZE_STRING_CONCAT) << "StringBuilder class not found";
49 return isApplied;
50 }
51
52 if (GetGraph()->IsAotMode()) {
53 // NOTE(mivanov): Creating StringBuilder.ctor calls in AOT mode not yet supported
54 return isApplied;
55 }
56
57 for (auto block : GetGraph()->GetBlocksRPO()) {
58 for (auto inst : block->Insts()) {
59 if (!IsMethodStringConcat(inst)) {
60 continue;
61 }
62
63 ReplaceStringConcatWithStringBuilderAppend(inst);
64 isApplied = true;
65 }
66 }
67
68 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Optimize String.concat complete";
69
70 if (isApplied) {
71 GetGraph()->RunPass<compiler::Cleanup>();
72 }
73
74 return isApplied;
75 }
76
InvalidateAnalyses()77 void OptimizeStringConcat::InvalidateAnalyses()
78 {
79 GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
80 GetGraph()->InvalidateAnalysis<AliasAnalysis>();
81 }
82
GetPhiConstantInput(Inst * phi)83 Inst *GetPhiConstantInput(Inst *phi)
84 {
85 ASSERT(phi->GetInputsCount() == 2U); // NOLINT(readability-magic-numbers)
86 ASSERT(phi->GetDataFlowInput(1)->IsPhi());
87 auto inputInst0 = phi->GetDataFlowInput(0);
88 if (inputInst0->IsConst()) {
89 return inputInst0;
90 }
91 if (inputInst0->IsPhi()) {
92 return GetPhiConstantInput(inputInst0);
93 }
94 UNREACHABLE();
95 }
96
GetArrayLength(Inst * newArray)97 Inst *GetArrayLength(Inst *newArray)
98 {
99 ASSERT(newArray->GetInputsCount() > 1);
100 auto inputInst1 = newArray->GetDataFlowInput(1);
101 if (inputInst1->IsConst()) {
102 return inputInst1;
103 }
104 if (inputInst1->IsPhi()) {
105 return GetPhiConstantInput(inputInst1);
106 }
107 UNREACHABLE();
108 }
109
CreateInstructionStringBuilderInstance(Graph * graph,uint32_t pc,SaveStateInst * saveState)110 Inst *CreateInstructionStringBuilderInstance(Graph *graph, uint32_t pc, SaveStateInst *saveState)
111 {
112 auto runtime = graph->GetRuntime();
113 auto method = graph->GetMethod();
114
115 auto classId = GetStringBuilderClassId(graph);
116 auto loadClass =
117 graph->CreateInstLoadAndInitClass(DataType::REFERENCE, pc, CopySaveState(graph, saveState),
118 TypeIdMixin {classId, method}, runtime->ResolveType(method, classId));
119 auto newObject = graph->CreateInstNewObject(DataType::REFERENCE, pc, loadClass, CopySaveState(graph, saveState),
120 TypeIdMixin {classId, method});
121
122 return newObject;
123 }
124
CreateStringBuilderAppendStringIntrinsic(Graph * graph,Inst * instance,Inst * arg,SaveStateInst * saveState)125 IntrinsicInst *CreateStringBuilderAppendStringIntrinsic(Graph *graph, Inst *instance, Inst *arg,
126 SaveStateInst *saveState)
127 {
128 auto appendIntrinsic = graph->CreateInstIntrinsic(graph->GetRuntime()->GetStringBuilderAppendStringIntrinsicId());
129 ASSERT(appendIntrinsic->RequireState());
130
131 appendIntrinsic->SetType(DataType::REFERENCE);
132 auto saveStateClone = CopySaveState(graph, saveState);
133 appendIntrinsic->SetInputs(
134 graph->GetAllocator(),
135 {{instance, instance->GetType()}, {arg, arg->GetType()}, {saveStateClone, saveStateClone->GetType()}});
136
137 return appendIntrinsic;
138 }
139
CreateStringBuilderToStringIntrinsic(Graph * graph,Inst * instance,SaveStateInst * saveState)140 IntrinsicInst *CreateStringBuilderToStringIntrinsic(Graph *graph, Inst *instance, SaveStateInst *saveState)
141 {
142 auto toStringCall = graph->CreateInstIntrinsic(graph->GetRuntime()->GetStringBuilderToStringIntrinsicId());
143 ASSERT(toStringCall->RequireState());
144
145 toStringCall->SetType(DataType::REFERENCE);
146 auto saveStateClone = CopySaveState(graph, saveState);
147 toStringCall->SetInputs(graph->GetAllocator(),
148 {{instance, instance->GetType()}, {saveStateClone, saveStateClone->GetType()}});
149
150 return toStringCall;
151 }
152
CreateStringBuilderDefaultConstructorCall(Graph * graph,Inst * instance,SaveStateInst * saveState)153 CallInst *CreateStringBuilderDefaultConstructorCall(Graph *graph, Inst *instance, SaveStateInst *saveState)
154 {
155 auto runtime = graph->GetRuntime();
156 auto method = runtime->GetStringBuilderDefaultConstructor();
157 auto methodId = runtime->GetMethodId(method);
158
159 auto ctorCall = graph->CreateInstCallStatic(DataType::VOID, instance->GetPc(), methodId, method);
160 ASSERT(ctorCall->RequireState());
161
162 auto saveStateClone = CopySaveState(graph, saveState);
163 ctorCall->SetInputs(graph->GetAllocator(),
164 {{instance, instance->GetType()}, {saveStateClone, saveStateClone->GetType()}});
165
166 return ctorCall;
167 }
168
CreateLoadArray(Graph * graph,Inst * array,Inst * index)169 Inst *CreateLoadArray(Graph *graph, Inst *array, Inst *index)
170 {
171 return graph->CreateInstLoadArray(DataType::REFERENCE, array->GetPc(), array, index);
172 }
173
CreateLoadArray(Graph * graph,Inst * array,uint64_t index)174 Inst *CreateLoadArray(Graph *graph, Inst *array, uint64_t index)
175 {
176 return CreateLoadArray(graph, array, graph->FindOrCreateConstant(index));
177 }
178
CreateLenArray(Graph * graph,Inst * newArray)179 Inst *CreateLenArray(Graph *graph, Inst *newArray)
180 {
181 return graph->CreateInstLenArray(DataType::INT32, newArray->GetPc(), newArray);
182 }
183
FixBrokenSaveStates(Inst * source,Inst * target)184 void OptimizeStringConcat::FixBrokenSaveStates(Inst *source, Inst *target)
185 {
186 if (source->IsMovableObject()) {
187 ssb_.SearchAndCreateMissingObjInSaveState(GetGraph(), source, target);
188 }
189 }
190
CreateAppendArgsIntrinsics(Inst * instance,Inst * args,uint64_t arrayLengthValue,SaveStateInst * saveState)191 void OptimizeStringConcat::CreateAppendArgsIntrinsics(Inst *instance, Inst *args, uint64_t arrayLengthValue,
192 SaveStateInst *saveState)
193 {
194 for (uint64_t index = 0; index < arrayLengthValue; ++index) {
195 auto arg = CreateLoadArray(GetGraph(), args, index);
196 auto appendIntrinsic = CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, arg, saveState);
197 InsertBeforeWithInputs(appendIntrinsic, saveState);
198
199 FixBrokenSaveStates(arg, appendIntrinsic);
200 FixBrokenSaveStates(instance, appendIntrinsic);
201
202 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
203 << "Insert StringBuilder.append intrinsic (id=" << appendIntrinsic->GetId() << ")";
204 }
205 }
206
CreateSafePoint(Graph * graph,uint32_t pc,SaveStateInst * saveState)207 Inst *CreateSafePoint(Graph *graph, uint32_t pc, SaveStateInst *saveState)
208 {
209 auto safePoint =
210 graph->CreateInstSafePoint(pc, graph->GetMethod(), saveState->GetCallerInst(), saveState->GetInliningDepth());
211
212 for (size_t index = 0; index < saveState->GetInputsCount(); ++index) {
213 safePoint->AppendInput(saveState->GetInput(index));
214 safePoint->SetVirtualRegister(index, saveState->GetVirtualRegister(index));
215 }
216
217 return safePoint;
218 }
219
CreateAppendArgsLoop(Inst * instance,Inst * str,Inst * args,LengthMethodInst * arrayLength,Inst * concatCall)220 BasicBlock *OptimizeStringConcat::CreateAppendArgsLoop(Inst *instance, Inst *str, Inst *args,
221 LengthMethodInst *arrayLength, Inst *concatCall)
222 {
223 auto preHeader = concatCall->GetBasicBlock();
224 auto postExit = preHeader->SplitBlockAfterInstruction(concatCall, false);
225 auto saveState = concatCall->GetSaveState();
226
227 // Create loop CFG
228 auto header = GetGraph()->CreateEmptyBlock(preHeader);
229 auto backEdge = GetGraph()->CreateEmptyBlock(preHeader);
230 preHeader->AddSucc(header);
231 header->AddSucc(postExit);
232 header->AddSucc(backEdge);
233 backEdge->AddSucc(header);
234
235 // Declare loop variables
236 auto start = GetGraph()->FindOrCreateConstant(0);
237 auto stop = arrayLength;
238 auto step = GetGraph()->FindOrCreateConstant(1);
239
240 auto pc = instance->GetPc();
241
242 // Build header
243 auto induction = GetGraph()->CreateInstPhi(DataType::INT32, pc);
244 auto safePoint = CreateSafePoint(GetGraph(), pc, saveState);
245 auto compare =
246 GetGraph()->CreateInstCompare(DataType::BOOL, pc, stop, induction, DataType::INT32, ConditionCode::CC_LE);
247 auto ifImm = GetGraph()->CreateInstIfImm(DataType::BOOL, pc, compare, 0, DataType::BOOL, ConditionCode::CC_NE);
248 header->AppendPhi(induction);
249 header->AppendInsts({
250 safePoint,
251 compare,
252 ifImm,
253 });
254
255 // Build back edge
256 auto arg = CreateLoadArray(GetGraph(), args, induction);
257 auto appendIntrinsic = CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, arg, saveState);
258 auto add = GetGraph()->CreateInstAdd(DataType::INT32, pc, induction, step);
259 backEdge->AppendInsts({
260 arg,
261 appendIntrinsic->GetSaveState(),
262 appendIntrinsic,
263 add,
264 });
265
266 // Connect loop induction variable inputs
267 induction->AppendInput(start);
268 induction->AppendInput(add);
269
270 FixBrokenSaveStates(str, appendIntrinsic);
271 FixBrokenSaveStates(args, appendIntrinsic);
272 FixBrokenSaveStates(arg, appendIntrinsic);
273 FixBrokenSaveStates(instance, appendIntrinsic);
274
275 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
276 << "Insert StringBuilder.append intrinsic (id=" << appendIntrinsic->GetId() << ")";
277
278 return postExit;
279 }
280
ReplaceStringConcatWithStringBuilderAppend(Inst * concatCall)281 void OptimizeStringConcat::ReplaceStringConcatWithStringBuilderAppend(Inst *concatCall)
282 {
283 // Input:
284 // let result = str.concat(...args)
285 //
286 // Output:
287 // let instance = new StringBuilder(str)
288 // instance.append(args[0])
289 // ...
290 // instance.append(args[args.length-1])
291 // let result = instance.toString()
292
293 ASSERT(concatCall->GetInputsCount() > 1);
294
295 auto str = concatCall->GetDataFlowInput(0);
296 auto args = concatCall->GetDataFlowInput(1);
297
298 auto instance = CreateInstructionStringBuilderInstance(GetGraph(), concatCall->GetPc(), concatCall->GetSaveState());
299 InsertBeforeWithInputs(instance, concatCall->GetSaveState());
300
301 auto ctorCall = CreateStringBuilderDefaultConstructorCall(GetGraph(), instance, concatCall->GetSaveState());
302 InsertBeforeWithSaveState(ctorCall, concatCall->GetSaveState());
303 auto appendArgIntrinsic =
304 CreateStringBuilderAppendStringIntrinsic(GetGraph(), instance, str, concatCall->GetSaveState());
305 InsertBeforeWithSaveState(appendArgIntrinsic, concatCall->GetSaveState());
306 FixBrokenSaveStates(instance, appendArgIntrinsic);
307 FixBrokenSaveStates(str, appendArgIntrinsic);
308
309 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT)
310 << "Insert StringBuilder.append intrinsic (id=" << appendArgIntrinsic->GetId() << ")";
311
312 auto toStringCall = CreateStringBuilderToStringIntrinsic(GetGraph(), instance, concatCall->GetSaveState());
313
314 if (args->GetOpcode() == Opcode::NewArray) {
315 auto arrayLength = GetArrayLength(args);
316 CreateAppendArgsIntrinsics(instance, args, arrayLength->CastToConstant()->GetIntValue(),
317 concatCall->GetSaveState());
318 InsertBeforeWithSaveState(toStringCall, concatCall->GetSaveState());
319
320 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Replace String.concat call (id=" << concatCall->GetId()
321 << ") with StringBuilder instance (id=" << instance->GetId() << ")";
322 } else {
323 auto arrayLength = CreateLenArray(GetGraph(), args);
324 concatCall->GetSaveState()->InsertBefore(arrayLength);
325 auto postExit = CreateAppendArgsLoop(instance, str, args, arrayLength->CastToLenArray(), concatCall);
326
327 postExit->PrependInst(toStringCall);
328 postExit->PrependInst(toStringCall->GetSaveState());
329
330 InvalidateBlocksOrderAnalyzes(GetGraph());
331 GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
332
333 COMPILER_LOG(DEBUG, OPTIMIZE_STRING_CONCAT) << "Replace String.concat call (id=" << concatCall->GetId()
334 << ") with StringBuilder instance (id=" << instance->GetId() << ")";
335 }
336
337 FixBrokenSaveStates(instance, toStringCall);
338
339 concatCall->ReplaceUsers(toStringCall);
340
341 concatCall->ClearFlag(inst_flags::NO_DCE);
342 if (concatCall->GetInput(0).GetInst()->IsCheck()) {
343 concatCall->GetInput(0).GetInst()->ClearFlag(inst_flags::NO_DCE);
344 }
345 }
346
347 } // namespace ark::compiler
348